[Kernel] Support Fp8 Checkpoints (Dynamic + Static) (vllm-project#4332)
Co-authored-by: Philipp Moritz <[email protected]>
Co-authored-by: Woosuk Kwon <[email protected]>
Co-authored-by: mgoin <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>
Co-authored-by: Cody Yu <[email protected]>
6 people authored Apr 30, 2024
1 parent 8d02dd4 commit 6f1b5c9
Showing 3 changed files with 307 additions and 40 deletions.
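At a high level, this commit teaches vLLM to load FP8-quantized checkpoints, covering both checkpoints that already carry per-tensor scales ("static") and ones whose scales are computed on the fly ("dynamic"), as the title suggests. A minimal usage sketch, assuming only what the new test below exercises (model name, context length, and sampling settings are copied from that test):

from vllm import LLM, SamplingParams

# Load an FP8 checkpoint; the quantization="fp8" flag and the model name
# are taken from tests/models/test_fp8.py added in this commit.
llm = LLM(model="nm-testing/Meta-Llama-3-8B-Instruct-FP8",
          quantization="fp8",
          max_model_len=1024,
          enforce_eager=True)

params = SamplingParams(max_tokens=20, temperature=0)
outputs = llm.generate(["Briefly explain FP8 quantization."], params)
print(outputs[0].outputs[0].text)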
90 changes: 90 additions & 0 deletions tests/models/test_fp8.py
@@ -0,0 +1,90 @@
# flake8: noqa
"""Tests fp8 models against ground truth generation
Note: these tests will only pass on L4 GPU.
"""
import os

import pytest
import torch
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS

os.environ["TOKENIZERS_PARALLELISM"] = "true"

MAX_MODEL_LEN = 1024

MODELS = [
"nm-testing/Meta-Llama-3-8B-Instruct-FP8",
"meta-llama/Meta-Llama-3-8B-Instruct",
]

EXPECTED_STRS_MAP = {
"nm-testing/Meta-Llama-3-8B-Instruct-FP8": [
'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (',
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne',
'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep',
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here',
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
'Here are the translations:\n\n**Japanese:** (Haya tori, nemuri nemuri)\n\n**'
],
"meta-llama/Meta-Llama-3-8B-Instruct": [
'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne',
'In the year 2154, the robotics lab at NeuroSpark Industries was on the cusp of',
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu'
],
}

capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
fp8_not_supported = (capability <
                     QUANTIZATION_METHODS["fp8"].get_min_capability())


@pytest.mark.skipif(fp8_not_supported,
                    reason="fp8 is not supported on this GPU type.")
@pytest.mark.parametrize("model_name", MODELS)
def test_models(
    example_prompts,
    model_name,
) -> None:
    model = LLM(model=model_name,
                max_model_len=MAX_MODEL_LEN,
                enforce_eager=True,
                quantization="fp8")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    formatted_prompts = [
        tokenizer.apply_chat_template([{
            "role": "user",
            "content": prompt
        }],
                                      tokenize=False,
                                      add_generation_prompt=True)
        for prompt in example_prompts
    ]

    params = SamplingParams(max_tokens=20, temperature=0)
    generations = []
    # Note: these need to be run 1 at a time due to numerical precision,
    # since the expected strs were generated this way.
    for prompt in formatted_prompts:
        outputs = model.generate(prompt, params)
        generations.append(outputs[0].outputs[0].text)
    del model

    print(generations)
    expected_strs = EXPECTED_STRS_MAP[model_name]
    for i in range(len(example_prompts)):
        generated_str = generations[i]
        expected_str = expected_strs[i]
        assert expected_str == generated_str, (
            f"Test{i}:\nExpected: {expected_str!r}\nvLLM: {generated_str!r}")
58 changes: 48 additions & 10 deletions vllm/model_executor/layers/linear.py
@@ -246,6 +246,10 @@ def __init__(
self.register_parameter("bias", None)

def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
# Special case for Fp8 scales.
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
None)

tp_rank = get_tensor_model_parallel_rank()
output_dim = getattr(param, "output_dim", None)
param_data = param.data
@@ -254,6 +258,12 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
# Special case for Fp8 scales.
elif fp8_scales_shard_indexer is not None:
param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
loaded_weight,
shard_id=0)

assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)

@@ -317,7 +327,12 @@ def weight_loader(self,

param_data = param.data
output_dim = getattr(param, "output_dim", None)
# Special case for AQLM codebooks.
is_metadata = getattr(param, "is_metadata", False)
# Special case for Fp8 scales.
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
None)

if loaded_shard_id is None:
# Loaded weight is already packed.
if output_dim is None:
@@ -331,14 +346,13 @@
current_shard_offset += output_size
packed_dim = getattr(param, "packed_dim", None)
for shard_id, shard_offset, shard_size in shard_offsets:
# Special case for Quantization.
# If quantized, we need to adjust the offset and size to account
# for the packing.
if packed_dim == output_dim:
shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor

# If marlin, we need to adjust the offset and size to
# account for the tiling.
# Special case for Marlin.
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)

@@ -353,15 +367,14 @@
if output_dim is not None:
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
shard_size = self.output_sizes[loaded_shard_id] // tp_size
# Special case for quantization.
# If quantized, we need to adjust the offset and size to account
# for the packing.
packed_dim = getattr(param, "packed_dim", None)
if packed_dim == output_dim:
shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor

# If marlin, we need to adjust the offset and size to
# account for the tiling.
# Special case for Marlin.
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)

@@ -370,11 +383,17 @@
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
# Special case for AQLM codebooks.
elif is_metadata:
# metadata indicates fixed size concatenated along dim 0
shard_size = loaded_weight.shape[0]
shard_offset = loaded_shard_id * shard_size
param_data = param_data.narrow(0, shard_offset, shard_size)
# Special case for Fp8 scales.
elif fp8_scales_shard_indexer is not None:
param_data, loaded_weight = fp8_scales_shard_indexer(
param_data, loaded_weight, loaded_shard_id)

else:
ignore_warning = getattr(param, "ignore_warning", False)
if not ignore_warning:
@@ -455,7 +474,11 @@ def weight_loader(self,
loaded_shard_id: Optional[str] = None):
param_data = param.data
output_dim = getattr(param, "output_dim", None)
# Special case for AQLM codebooks.
is_metadata = getattr(param, "is_metadata", False)
# Special case for Fp8 scales.
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
None)

if loaded_shard_id is None:
# Loaded weight is already packed.
@@ -473,14 +496,14 @@
]
packed_dim = getattr(param, "packed_dim", None)
for shard_id, shard_offset, shard_size in shard_offsets:
# Special case for Quantized Weights.
# If quantized, we need to adjust the offset and size to account
# for the packing.
if packed_dim == output_dim:
shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor

# If marlin, we need to adjust the offset and size to
# account for the tiling.
# Special case for Marlin.
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)

@@ -502,15 +525,15 @@
shard_offset = (self.num_heads +
self.num_kv_heads) * self.head_size
shard_size = self.num_kv_heads * self.head_size
# Special case for Quantized Weights.
# If quantized, we need to adjust the offset and size to account
# for the packing.
packed_dim = getattr(param, "packed_dim", None)
if packed_dim == output_dim:
shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor

# If marlin, we need to adjust the offset and size to
# account for the tiling.
# Special case for Marlin.
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)

@@ -523,12 +546,17 @@
start_idx = shard_id * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
# Special case for AQLM codebooks.
elif is_metadata:
# metadata indicates fixed size concatenated along dim 0
shard_size = loaded_weight.shape[0]
shard_index = ["q", "k", "v"].index(loaded_shard_id)
param_data = param_data.narrow(0, shard_index * shard_size,
shard_size)
# Special case for Fp8 scales.
elif fp8_scales_shard_indexer is not None:
param_data, loaded_weight = fp8_scales_shard_indexer(
param_data, loaded_weight, loaded_shard_id)
else:
ignore_warning = getattr(param, "ignore_warning", False)
if not ignore_warning:
@@ -611,6 +639,10 @@ def __init__(
self.register_parameter("bias", None)

def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
# Special case for Fp8 scales.
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
None)

tp_rank = get_tensor_model_parallel_rank()
input_dim = getattr(param, "input_dim", None)
param_data = param.data
@@ -619,6 +651,12 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(input_dim, start_idx,
shard_size)
# Special case for Fp8 scales.
elif fp8_scales_shard_indexer is not None:
param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
loaded_weight,
shard_id=0)

assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
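The fp8_scales_shard_indexer attribute queried in these loaders is attached to the scale parameters by the fp8 quantization method (the third changed file in this commit, whose diff is not shown above). From the call sites, its contract is to take the destination parameter data, the loaded scale, and a shard id (an integer, or "q"/"k"/"v" for fused QKV), and return the pair to copy, already narrowed so the final shape assertion holds. A hypothetical sketch consistent with those call sites, not the code from this commit:

from typing import Tuple, Union

import torch


def fp8_scales_shard_indexer(
        param_data: torch.Tensor, loaded_weight: torch.Tensor,
        shard_id: Union[str, int]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Illustrative only: select the slot a per-shard scale is copied into.

    The real indexer is defined by the fp8 quantization method and attached
    to the parameter as an attribute (hence the getattr() lookups above).
    """
    # Fused QKV layers pass string shard ids; map them to integer slots.
    qkv_idxs = {"q": 0, "k": 1, "v": 2}
    if isinstance(shard_id, str):
        shard_id = qkv_idxs[shard_id]
    # Scales are per-tensor scalars, so instead of narrowing along the
    # tensor-parallel dimension, index the slot for this logical shard.
    return param_data[shard_id], loaded_weight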

