[Kernel] Support Fp8 Checkpoints (Dynamic + Static) #4332

Merged (93 commits, Apr 30, 2024)
Changes from 91 commits

Commits (93)
79c94a1
fixed fp8 conflict with aqlm
robertgshaw2-neuralmagic Apr 23, 2024
f8b57e4
added quantization tests to buildkite
robertgshaw2-neuralmagic Apr 23, 2024
7175e5b
removed commented out piece
robertgshaw2-neuralmagic Apr 23, 2024
7a7520d
model loaded!
robertgshaw2-neuralmagic Apr 23, 2024
e0b4d72
renamed
robertgshaw2-neuralmagic Apr 23, 2024
f96428e
stash
robertgshaw2-neuralmagic Apr 24, 2024
88ba83b
added static fp8
robertgshaw2-neuralmagic Apr 24, 2024
0848b25
to try with torch.scaled_mm
robertgshaw2-neuralmagic Apr 24, 2024
15882ea
stash
robertgshaw2-neuralmagic Apr 24, 2024
7e6b675
added way to do weight quantization
robertgshaw2-neuralmagic Apr 24, 2024
cc959ea
working!
robertgshaw2-neuralmagic Apr 24, 2024
8d68dbc
fixed llama
robertgshaw2-neuralmagic Apr 24, 2024
881fc65
fixed llama again
robertgshaw2-neuralmagic Apr 24, 2024
e6dd46f
updated names
robertgshaw2-neuralmagic Apr 24, 2024
7e3933b
nit
robertgshaw2-neuralmagic Apr 24, 2024
453a236
cleanup
robertgshaw2-neuralmagic Apr 24, 2024
310e0a7
cleanup
robertgshaw2-neuralmagic Apr 24, 2024
ab4cb02
missed file :)
robertgshaw2-neuralmagic Apr 24, 2024
2edd93a
Update fp8.py
robertgshaw2-neuralmagic Apr 24, 2024
ccee5d3
Implement static scaling for Mixtral
pcmoritz Apr 24, 2024
8f71c79
fix
pcmoritz Apr 24, 2024
6eb01e0
update
pcmoritz Apr 24, 2024
dc89cbc
fix
pcmoritz Apr 24, 2024
be60845
update
pcmoritz Apr 24, 2024
4613cb5
update
pcmoritz Apr 24, 2024
3d95d86
fix
pcmoritz Apr 24, 2024
642763f
move
pcmoritz Apr 24, 2024
706e931
update
pcmoritz Apr 24, 2024
9a3c78c
lol
pcmoritz Apr 24, 2024
1b6f020
fix cuda graph
pcmoritz Apr 24, 2024
b09bcec
fix
pcmoritz Apr 24, 2024
052e2b3
update
pcmoritz Apr 24, 2024
b33c6d7
update
pcmoritz Apr 25, 2024
475f58d
refactor
pcmoritz Apr 25, 2024
56b4880
update
pcmoritz Apr 25, 2024
be37154
revert
pcmoritz Apr 25, 2024
9c54d19
format
pcmoritz Apr 25, 2024
c5155ea
Update vllm/_custom_ops.py
pcmoritz Apr 25, 2024
948cca7
Update vllm/model_executor/layers/fused_moe/fused_moe.py
pcmoritz Apr 25, 2024
3feb887
Update vllm/model_executor/models/mixtral.py
pcmoritz Apr 25, 2024
df16316
format
pcmoritz Apr 25, 2024
7b6b0fa
support static scales
robertgshaw2-neuralmagic Apr 25, 2024
1a3b2e1
fixed example
robertgshaw2-neuralmagic Apr 25, 2024
63ad2ef
Delete quantize.ipynb
robertgshaw2-neuralmagic Apr 25, 2024
794f1a1
Update vllm/_custom_ops.py
pcmoritz Apr 25, 2024
c13b6a4
update
pcmoritz Apr 25, 2024
5a230ed
update
pcmoritz Apr 25, 2024
80069c9
format
pcmoritz Apr 25, 2024
5ce17d0
activation_scale -> act_scale
pcmoritz Apr 25, 2024
5fc0335
Update scheme->activation_scheme
mgoin Apr 25, 2024
92d5162
fix dynamic scaling -- need init to zero due to atomic update
pcmoritz Apr 25, 2024
e1bfe10
Format
mgoin Apr 25, 2024
7242600
Fix tuple type
mgoin Apr 25, 2024
8512513
Merge remote-tracking branch 'pcmoritz/mixtral-fp8-static' into fp8-s…
robertgshaw2-neuralmagic Apr 26, 2024
21ddbb4
stash tyler's state
robertgshaw2-neuralmagic Apr 26, 2024
d27015c
stash
robertgshaw2-neuralmagic Apr 26, 2024
1111f87
cutlass working, but slow jitting on hotpath
robertgshaw2-neuralmagic Apr 26, 2024
f5d32ae
first end to end run with mixtral
robertgshaw2-neuralmagic Apr 26, 2024
924e8ce
added missed file
robertgshaw2-neuralmagic Apr 26, 2024
823a2e7
Update run_fp8.py
mgoin Apr 26, 2024
81f42be
Dynamic FP8 works, but static does not (#213)
robertgshaw2-neuralmagic Apr 27, 2024
1a4fd8a
static correctness
robertgshaw2-neuralmagic Apr 27, 2024
e48c981
static fp8 loading
robertgshaw2-neuralmagic Apr 27, 2024
02f683e
working for dense models
robertgshaw2-neuralmagic Apr 27, 2024
81b73ef
Update weight_utils.py
robertgshaw2-neuralmagic Apr 27, 2024
58dbe0f
moving mixtral updates to separate pr
robertgshaw2-neuralmagic Apr 27, 2024
6068dc5
Merge branch 'main' into fp8-static
robertgshaw2-neuralmagic Apr 27, 2024
a8d4b33
make ./format pass
robertgshaw2-neuralmagic Apr 27, 2024
5be0970
better comments in linear.py
robertgshaw2-neuralmagic Apr 27, 2024
ef7992b
better comments in linear.py
robertgshaw2-neuralmagic Apr 27, 2024
0667791
fixed opt-125
robertgshaw2-neuralmagic Apr 27, 2024
d8adf14
removed run_fp8.py
robertgshaw2-neuralmagic Apr 27, 2024
9bb1a2b
format
robertgshaw2-neuralmagic Apr 27, 2024
169c9ed
Cleanup opt.py
mgoin Apr 27, 2024
8ef9c7d
added testing
robertgshaw2-neuralmagic Apr 27, 2024
c7d6dd6
./format.sh
robertgshaw2-neuralmagic Apr 27, 2024
50b5823
fixed typing
robertgshaw2-neuralmagic Apr 27, 2024
4156ca9
fixed typing
robertgshaw2-neuralmagic Apr 27, 2024
3148fc9
added warning format
robertgshaw2-neuralmagic Apr 27, 2024
7846d67
Update opt.py
robertgshaw2-neuralmagic Apr 27, 2024
ba408c6
formatted
robertgshaw2-neuralmagic Apr 27, 2024
04617fd
Update vllm/model_executor/layers/quantization/fp8.py
robertgshaw2-neuralmagic Apr 27, 2024
cc3d395
Update vllm/model_executor/layers/quantization/fp8.py
robertgshaw2-neuralmagic Apr 27, 2024
f556016
auto detect shared scale (#214)
robertgshaw2-neuralmagic Apr 28, 2024
30bfbd8
./format.sh
robertgshaw2-neuralmagic Apr 28, 2024
572107a
Update vllm/model_executor/layers/quantization/fp8.py
robertgshaw2-neuralmagic Apr 29, 2024
41fbde9
./format.sh
robertgshaw2-neuralmagic Apr 29, 2024
f2cd561
addressed cody's comments + format
robertgshaw2-neuralmagic Apr 29, 2024
125266e
make mypy happy
robertgshaw2-neuralmagic Apr 29, 2024
8a566a7
Merge remote-tracking branch 'upstream/main' into fp8-static
robertgshaw2-neuralmagic Apr 30, 2024
280a4d5
test
robertgshaw2-neuralmagic Apr 30, 2024
8e1ede1
cleaned up
robertgshaw2-neuralmagic Apr 30, 2024
d067428
Update vllm/model_executor/layers/quantization/fp8.py
robertgshaw2-neuralmagic Apr 30, 2024
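For context, the feature this PR adds can be exercised through vLLM's offline API in the same way the new test below does. The following is a minimal usage sketch, not an official example from this PR; the model names are the test checkpoints referenced in tests/models/test_fp8.py.

from vllm import LLM, SamplingParams

# Load a checkpoint that already ships fp8 weights (and, for the "static"
# activation scheme, per-layer act_scale tensors).
llm = LLM(model="nm-testing/mistral-fp8-static", quantization="fp8")

# A regular fp16 checkpoint can also be quantized to fp8 at load time
# (activation scales are then computed dynamically at runtime):
# llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2", quantization="fp8")

out = llm.generate("Hello!", SamplingParams(max_tokens=16, temperature=0))
print(out[0].outputs[0].text)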
101 changes: 101 additions & 0 deletions tests/models/test_fp8.py
@@ -0,0 +1,101 @@
# flake8: noqa
"""Tests fp8 models against ground truth generation
Note: these tests will only pass on L4 GPU.
"""
import os

import pytest
import torch
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS

os.environ["TOKENIZERS_PARALLELISM"] = "true"

MAX_MODEL_LEN = 1024

MODELS = [
    "nm-testing/mistral-fp8-static",
    "nm-testing/mistral-fp8-dynamic",
    "mistralai/Mistral-7B-Instruct-v0.2",
]

EXPECTED_STRS_MAP = {
    "nm-testing/mistral-fp8-static": [
        ' VLLM (Vulcan Language Model) is an open-source inference and serving engine',
        ' 1. 1950s: The Concept of AI is Born: The term',
        ' Artificial Intelligence (AI) and Human Intelligence (HI) are two distinct ways of processing information.',
        ' A neural network is a type of machine learning model inspired by the structure and function of the human brain',
        ' In the heart of the bustling city of Neo-Tokyo, nestled among the tow',
        ' The COVID-19 pandemic has had a profound impact on global economic structures and has forced businesses to',
        ' The Mona Lisa painting, created by the Italian artist Leonardo da Vinci between 15',
        ' Japanese: 早く起きる'
    ],
    "nm-testing/mistral-fp8-dynamic": [
        ' VLLM (Vulcan Language Model) is an open-source, high-throughput',
        ' 1. 1950s: The Concept of AI is Born: The term',
        ' Artificial Intelligence (AI) and Human Intelligence (HI) are two distinct ways of processing information.',
        " A neural network is a type of machine learning model inspired by the human brain's structure and function",
        ' Once upon a time, in the heart of a bustling city, there was a robot named B',
        ' The COVID-19 pandemic has had a profound impact on global economic structures and has forced businesses to',
        ' The Mona Lisa painting, created by the Italian artist Leonardo da Vinci between 15',
        ' Japanese: 早く起きる鳥は虫を取る (S'
    ],
    "mistralai/Mistral-7B-Instruct-v0.2": [
        ' VLLM (Vulcan Language Model) is an open-source, high-throughput',
        ' 1. 1950s: The Concept of AI is Born: The term',
        ' Artificial Intelligence (AI) and Human Intelligence (HI) are two distinct ways of processing information.',
        " A neural network is a type of machine learning model inspired by the human brain's structure and function",
        ' In the heart of the bustling city of Neo-Tokyo, nestled among the tow',
        ' The COVID-19 pandemic has had a profound impact on global economic structures and has forced businesses to',
        ' The Mona Lisa painting, created by the Italian artist Leonardo da Vinci between 15',
        ' Japanese: 早く起きる鳥は虫を取る (S'
    ],
}

capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
fp8_not_supported = (capability <
                     QUANTIZATION_METHODS["fp8"].get_min_capability())


@pytest.mark.skipif(fp8_not_supported,
                    reason="fp8 is not supported on this GPU type.")
@pytest.mark.parametrize("model_name", MODELS)
def test_models(
    example_prompts,
    model_name,
) -> None:
    model = LLM(model=model_name,
                max_model_len=MAX_MODEL_LEN,
                enforce_eager=True,
                quantization="fp8")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    formatted_prompts = [
        tokenizer.apply_chat_template([{
            "role": "user",
            "content": prompt
        }],
                                      tokenize=False,
                                      add_generation_prompt=True)
        for prompt in example_prompts
    ]

    params = SamplingParams(max_tokens=20, temperature=0)
    generations = []
    # Note: these need to be run 1 at a time due to numerical precision,
    # since the expected strs were generated this way.
    for prompt in formatted_prompts:
        outputs = model.generate(prompt, params)
        generations.append(outputs[0].outputs[0].text)
    del model

    print(generations)
    expected_strs = EXPECTED_STRS_MAP[model_name]
    for i in range(len(example_prompts)):
        generated_str = generations[i]
        expected_str = expected_strs[i]
        assert expected_str == generated_str, (
            f"Test{i}:\nExpected: {expected_str!r}\nvLLM: {generated_str!r}")
58 changes: 48 additions & 10 deletions vllm/model_executor/layers/linear.py
@@ -246,6 +246,10 @@ def __init__(
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        # Special case for Fp8 scales.
        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
                                           None)

        tp_rank = get_tensor_model_parallel_rank()
        output_dim = getattr(param, "output_dim", None)
        param_data = param.data
@@ -254,6 +258,12 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
        # Special case for Fp8 scales.
        elif fp8_scales_shard_indexer is not None:
            param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
                                                                 loaded_weight,
                                                                 shard_id=0)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

@@ -317,7 +327,12 @@ def weight_loader(self,

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for AQLM codebooks.
        is_metadata = getattr(param, "is_metadata", False)
        # Special case for Fp8 scales.
        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
                                           None)

        if loaded_shard_id is None:
            # Loaded weight is already packed.
            if output_dim is None:
@@ -331,14 +346,13 @@ def weight_loader(self,
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantization.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor

                    # If marlin, we need to adjust the offset and size to
                    # account for the tiling.
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

@@ -353,15 +367,14 @@ def weight_loader(self,
        if output_dim is not None:
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            shard_size = self.output_sizes[loaded_shard_id] // tp_size
            # Special case for quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor

                # If marlin, we need to adjust the offset and size to
                # account for the tiling.
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

@@ -370,11 +383,17 @@ def weight_loader(self,
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
        # Special case for AQLM codebooks.
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_offset = loaded_shard_id * shard_size
            param_data = param_data.narrow(0, shard_offset, shard_size)
        # Special case for Fp8 scales.
        elif fp8_scales_shard_indexer is not None:
            param_data, loaded_weight = fp8_scales_shard_indexer(
                param_data, loaded_weight, loaded_shard_id)

        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
@@ -455,7 +474,11 @@ def weight_loader(self,
                      loaded_shard_id: Optional[str] = None):
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for AQLM codebooks.
        is_metadata = getattr(param, "is_metadata", False)
        # Special case for Fp8 scales.
        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
                                           None)

        if loaded_shard_id is None:
            # Loaded weight is already packed.
@@ -473,14 +496,14 @@ def weight_loader(self,
            ]
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor

                    # If marlin, we need to adjust the offset and size to
                    # account for the tiling.
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

@@ -502,15 +525,15 @@ def weight_loader(self,
                shard_offset = (self.num_heads +
                                self.num_kv_heads) * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor

                # If marlin, we need to adjust the offset and size to
                # account for the tiling.
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

@@ -523,12 +546,17 @@ def weight_loader(self,
            start_idx = shard_id * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
        # Special case for AQLM codebooks.
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_index = ["q", "k", "v"].index(loaded_shard_id)
            param_data = param_data.narrow(0, shard_index * shard_size,
                                           shard_size)
        # Special case for Fp8 scales.
        elif fp8_scales_shard_indexer is not None:
            param_data, loaded_weight = fp8_scales_shard_indexer(
                param_data, loaded_weight, loaded_shard_id)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
@@ -611,6 +639,10 @@ def __init__(
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        # Special case for Fp8 scales.
        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
                                           None)

        tp_rank = get_tensor_model_parallel_rank()
        input_dim = getattr(param, "input_dim", None)
        param_data = param.data
@@ -619,6 +651,12 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                 shard_size)
        # Special case for Fp8 scales.
        elif fp8_scales_shard_indexer is not None:
            param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
                                                                 loaded_weight,
                                                                 shard_id=0)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

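The linear-layer changes above route fp8 scale loading through an optional fp8_scales_shard_indexer attribute that the weight loaders look up with getattr(param, "fp8_scales_shard_indexer", None). The hook itself lives in vllm/model_executor/layers/quantization/fp8.py, which is not part of this diff; the following is only a hypothetical sketch of the contract such a callable has to satisfy (one scale slot per logical shard, with QKV shard names mapped to indices), not the PR's implementation.

from typing import Tuple, Union

import torch


def fp8_scales_shard_indexer(
        param_data: torch.Tensor,
        loaded_weight: torch.Tensor,
        shard_id: Union[str, int],
) -> Tuple[torch.Tensor, torch.Tensor]:
    # QKVParallelLinear passes "q"/"k"/"v"; merged column layers pass ints.
    qkv_idxs = {"q": 0, "k": 1, "v": 2}
    if isinstance(shard_id, str):
        shard_id = qkv_idxs[shard_id]
    # Select the slot of the per-shard scale parameter that this loaded
    # scale should land in; the caller performs the final copy_.
    return param_data[shard_id], loaded_weight

Because the loaders only check for the attribute, any quantization method can opt in by attaching a callable like this to its scale parameters (for example via set_weight_attrs).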