Skip to content

Commit

Permalink
[Misc] Load FP8 kv-cache scaling factors from checkpoints (vllm-proje…
Browse files Browse the repository at this point in the history
…ct#4893)

The 2nd PR for vllm-project#4532.

This PR supports loading FP8 kv-cache scaling factors from a FP8 checkpoint (with .kv_scale parameter).
  • Loading branch information
comaniac authored and joerunde committed Jun 3, 2024
1 parent 56382f4 commit 0fc07da
Show file tree
Hide file tree
Showing 40 changed files with 284 additions and 158 deletions.
14 changes: 6 additions & 8 deletions benchmarks/benchmark_latency.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,15 +153,13 @@ def run_to_completion(profile_dir: Optional[str] = None):
action='store_true',
help='enforce eager mode and disable CUDA graph')
parser.add_argument(
"--kv-cache-dtype",
'--kv-cache-dtype',
type=str,
choices=['auto', 'fp8'],
default='auto',
help=
'Data type for kv cache storage. If "auto", will use model data type. '
'FP8_E5M2 (without scaling) is only supported on cuda version greater '
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
'instead supported for common inference criteria.')
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
default="auto",
help='Data type for kv cache storage. If "auto", will use model '
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
parser.add_argument(
'--quantization-param-path',
type=str,
Expand Down
12 changes: 5 additions & 7 deletions benchmarks/benchmark_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,15 +323,13 @@ def main(args: argparse.Namespace):
action="store_true",
help="enforce eager execution")
parser.add_argument(
"--kv-cache-dtype",
'--kv-cache-dtype',
type=str,
choices=["auto", "fp8"],
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
default="auto",
help=
'Data type for kv cache storage. If "auto", will use model data type. '
'FP8_E5M2 (without scaling) is only supported on cuda version greater '
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
'common inference criteria.')
help='Data type for kv cache storage. If "auto", will use model '
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
parser.add_argument(
'--quantization-param-path',
type=str,
Expand Down
10 changes: 4 additions & 6 deletions benchmarks/kernels/benchmark_paged_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,13 +183,11 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
parser.add_argument(
"--kv-cache-dtype",
type=str,
choices=["auto", "fp8"],
choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"],
default="auto",
help=
'Data type for kv cache storage. If "auto", will use model data type. '
'FP8_E5M2 (without scaling) is only supported on cuda version greater '
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
'common inference criteria.')
help="Data type for kv cache storage. If 'auto', will use model "
"data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
"ROCm (AMD GPU) supports fp8 (=fp8_e4m3)")
args = parser.parse_args()
print(args)

Expand Down
80 changes: 52 additions & 28 deletions tests/models/test_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,31 +16,55 @@
MAX_MODEL_LEN = 1024

MODELS = [
"nm-testing/Meta-Llama-3-8B-Instruct-FP8",
"nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV",
"meta-llama/Meta-Llama-3-8B-Instruct",
]

EXPECTED_STRS_MAP = {
"nm-testing/Meta-Llama-3-8B-Instruct-FP8": [
'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (',
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne',
'Zeta-5, a highly advanced robot designed for menial labor, whirred to a',
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
'Here are the translations:\n\n**Japanese:** (Haya aki no tori, guri o',
],
"meta-llama/Meta-Llama-3-8B-Instruct": [
'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne',
'In the vast, sterile laboratory, Robot 3456-Alpha, or "Alpha" for short',
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu'
],
"nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV": {
"auto": [
'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (',
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
'Artificial intelligence (AI) and human intelligence (HI) process information in distinct ways, with both',
'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne',
'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep',
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
'Here are the translations:\n\n**Japanese:** (Haya aki no tori, nemuri no'
],
"fp8": [
'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
'A neural network is a complex system made up of several basic components that work together to enable it to',
'Zeta-5, a highly advanced robot designed for menial labor, had never experienced anything like',
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here',
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
'Here are the translations:\n\n**Japanese:** (Haya kotori wa mushi o tsuk'
]
},
"meta-llama/Meta-Llama-3-8B-Instruct": {
"auto": [
'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne',
'In the vast, sterile laboratory, Robot 3456-Alpha, or "Alpha" for short',
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu'
],
"fp8": [
'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne',
'In the year 2154, robotics engineer Dr. Rachel Kim had spent years perfecting her latest',
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
'Here are the translations:\n\n**Japanese:** (Haya tori, mushi o tsukamu'
]
},
}

capability = torch.cuda.get_device_capability()
Expand All @@ -52,14 +76,14 @@
@pytest.mark.skipif(fp8_not_supported,
reason="fp8 is not supported on this GPU type.")
@pytest.mark.parametrize("model_name", MODELS)
def test_models(
example_prompts,
model_name,
) -> None:
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
def test_models(example_prompts, model_name, kv_cache_dtype) -> None:
model = LLM(model=model_name,
max_model_len=MAX_MODEL_LEN,
trust_remote_code=True,
enforce_eager=True,
quantization="fp8")
quantization="fp8",
kv_cache_dtype=kv_cache_dtype)

tokenizer = AutoTokenizer.from_pretrained(model_name)
formatted_prompts = [
Expand All @@ -81,8 +105,8 @@ def test_models(
generations.append(outputs[0].outputs[0].text)
del model

print(generations)
expected_strs = EXPECTED_STRS_MAP[model_name]
print(model_name, kv_cache_dtype, generations)
expected_strs = EXPECTED_STRS_MAP[model_name][kv_cache_dtype]
for i in range(len(example_prompts)):
generated_str = generations[i]
expected_str = expected_strs[i]
Expand Down
27 changes: 25 additions & 2 deletions vllm/attention/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)


class Attention(nn.Module):
Expand All @@ -30,6 +32,7 @@ def __init__(
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
if cache_config is not None:
Expand All @@ -40,6 +43,27 @@ def __init__(
block_size = 16
if num_kv_heads is None:
num_kv_heads = num_heads

# The default kv_scale is set to 1.0. This is ignored
# when kv-cache is not fp8, and should be used with
# kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
# expect the pre-quantized kv_scale to be loaded along
# with the model weights.
self.kv_cache_dtype = kv_cache_dtype
self._kv_scale = 1.0
quant_method = quant_config.get_quant_method(
self) if quant_config else None
if quant_method is not None:
if self.kv_cache_dtype == "fp8_e5m2":
raise ValueError("fp8_e5m2 kv-cache is not supported with "
"fp8 checkpoints.")
# When FP8 quantization is enabled, we make a parameter
# "kv_scale" so that it can be loaded from FP8 checkpoint.
# The kv_scale will then be converted back
# to self._kv_scale in a native float32 value after weight loading.
self.quant_method = quant_method
self.quant_method.create_weights(self)

# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()
Expand All @@ -57,10 +81,9 @@ def forward(
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata,
kv_scale: float = 1.0,
) -> torch.Tensor:
return self.impl.forward(query, key, value, kv_cache, attn_metadata,
kv_scale)
self._kv_scale)

def extra_repr(self) -> str:
s = f"head_size={self.impl.head_size}" # type: ignore
Expand Down
8 changes: 3 additions & 5 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,14 +355,12 @@ def _verify_args(self) -> None:
def _verify_cache_dtype(self) -> None:
if self.cache_dtype == "auto":
pass
elif self.cache_dtype == "fp8":
elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
logger.info(
"Using fp8 data type to store kv cache. It reduces the GPU "
"memory footprint and boosts the performance. "
"But it may cause slight accuracy drop without scaling "
"factors. FP8_E5M2 (without scaling) is only supported on "
"cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 "
"is instead supported for common inference criteria.")
"Meanwhile, it may cause accuracy drop without a proper "
"scaling factor")
else:
raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")

Expand Down
7 changes: 3 additions & 4 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,12 +191,11 @@ def add_cli_args(
parser.add_argument(
'--kv-cache-dtype',
type=str,
choices=['auto', 'fp8'],
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
default=EngineArgs.kv_cache_dtype,
help='Data type for kv cache storage. If "auto", will use model '
'data type. FP8_E5M2 (without scaling) is only supported on cuda '
'version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
'supported for common inference criteria.')
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
parser.add_argument(
'--quantization-param-path',
type=nullable_str,
Expand Down
47 changes: 45 additions & 2 deletions vllm/model_executor/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import print_warning_once

ACTIVATION_SCHEMES = ["static", "dynamic"]

Expand Down Expand Up @@ -58,9 +59,13 @@ def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
activation_scheme=activation_scheme)

def get_quant_method(
self, layer: torch.nn.Module) -> Optional["Fp8LinearMethod"]:
self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import

if isinstance(layer, LinearBase):
return Fp8LinearMethod(self)
if isinstance(layer, Attention):
return Fp8KVCacheMethod(self)
return None

def get_scaled_act_names(self) -> List[str]:
Expand Down Expand Up @@ -251,6 +256,44 @@ def apply(self,
return torch.narrow(output, 0, 0, x.shape[0])


class Fp8KVCacheMethod(QuantizeMethodBase):
"""Supports loading kv-cache scaling factors from FP8 checkpoints.
"""

def __init__(self, quant_config: Fp8Config):
self.quant_config = quant_config

def create_weights(self, layer: torch.nn.Module):
"""Create "weight" (aka kv_scale) for an attention layer.
Args:
layer: The layer that is using the QuantizeMethodBase factory.
"""
# Initialize the KV cache scale to 1.0 as the default value.
# If the kv_scale appears in the checkpoint, it will be
# overwritten when loading weights.
layer.kv_scale = Parameter(torch.tensor(1.0), requires_grad=False)

def apply(self, layer: torch.nn.Module) -> torch.Tensor:
raise RuntimeError("Fp8KVCacheMethod.apply should not be called.")

def process_weights_after_loading(self, layer: Module) -> None:
# If the kv-cache dtype is auto, we enforce the kv-scale to be 1.0
# regardless whether the kv-scale is available in the checkpoint.
if layer.kv_cache_dtype != "auto":
kv_scale = layer.kv_scale.to("cpu").tolist()
if not isinstance(kv_scale, float):
raise ValueError("Only support per-tensor scaling factor "
"for fp8 KV cache")
layer._kv_scale = kv_scale
if layer._kv_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype:
print_warning_once(
"Using KV cache scaling factor 1.0 for fp8_e4m3. This may "
"cause accuracy issues. Please make sure kv-cache scaling "
"factor is available in the fp8 checkpoint.")
del layer.kv_scale


def all_close_1d(x: torch.Tensor) -> bool:
assert len(x.shape) == 1
return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
Expand Down
3 changes: 2 additions & 1 deletion vllm/model_executor/models/arctic.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,8 @@ def __init__(
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)

def forward(
self,
Expand Down
6 changes: 4 additions & 2 deletions vllm/model_executor/models/baichuan.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ def __init__(
self.attn = Attention(self.num_heads,
self.head_dim,
scaling,
alibi_slopes=alibi_slopes)
alibi_slopes=alibi_slopes,
quant_config=quant_config)
else:
self.rotary_emb = get_rope(
self.head_dim,
Expand All @@ -166,7 +167,8 @@ def __init__(
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)

def forward(
self,
Expand Down
3 changes: 2 additions & 1 deletion vllm/model_executor/models/bloom.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ def __init__(
self.head_dim,
scaling,
alibi_slopes=alibi_slopes,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)

def forward(
self,
Expand Down
13 changes: 6 additions & 7 deletions vllm/model_executor/models/chatglm.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,12 @@ def __init__(
base=10000 * rope_ratio,
is_neox_style=False,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)

def forward(
self,
Expand Down
Loading

0 comments on commit 0fc07da

Please sign in to comment.