[Kernel] Support Fp8 Checkpoints (Dynamic + Static) #4332

Merged: 93 commits (Apr 30, 2024)

Changes from 59 commits

Commits (93)
79c94a1
fixed fp8 conflict with aqlm
robertgshaw2-neuralmagic Apr 23, 2024
f8b57e4
added quantization tests to buildkite
robertgshaw2-neuralmagic Apr 23, 2024
7175e5b
removed commented out piece
robertgshaw2-neuralmagic Apr 23, 2024
7a7520d
model loaded!
robertgshaw2-neuralmagic Apr 23, 2024
e0b4d72
renamed
robertgshaw2-neuralmagic Apr 23, 2024
f96428e
stash
robertgshaw2-neuralmagic Apr 24, 2024
88ba83b
added static fp8
robertgshaw2-neuralmagic Apr 24, 2024
0848b25
to try with torch.scaled_mm
robertgshaw2-neuralmagic Apr 24, 2024
15882ea
stash
robertgshaw2-neuralmagic Apr 24, 2024
7e6b675
added way to do weight quantization
robertgshaw2-neuralmagic Apr 24, 2024
cc959ea
working!
robertgshaw2-neuralmagic Apr 24, 2024
8d68dbc
fixed llama
robertgshaw2-neuralmagic Apr 24, 2024
881fc65
fixed llama again
robertgshaw2-neuralmagic Apr 24, 2024
e6dd46f
updated names
robertgshaw2-neuralmagic Apr 24, 2024
7e3933b
nit
robertgshaw2-neuralmagic Apr 24, 2024
453a236
cleanup
robertgshaw2-neuralmagic Apr 24, 2024
310e0a7
cleanup
robertgshaw2-neuralmagic Apr 24, 2024
ab4cb02
missed file :)
robertgshaw2-neuralmagic Apr 24, 2024
2edd93a
Update fp8.py
robertgshaw2-neuralmagic Apr 24, 2024
ccee5d3
Implement static scaling for Mixtral
pcmoritz Apr 24, 2024
8f71c79
fix
pcmoritz Apr 24, 2024
6eb01e0
update
pcmoritz Apr 24, 2024
dc89cbc
fix
pcmoritz Apr 24, 2024
be60845
update
pcmoritz Apr 24, 2024
4613cb5
update
pcmoritz Apr 24, 2024
3d95d86
fix
pcmoritz Apr 24, 2024
642763f
move
pcmoritz Apr 24, 2024
706e931
update
pcmoritz Apr 24, 2024
9a3c78c
lol
pcmoritz Apr 24, 2024
1b6f020
fix cuda graph
pcmoritz Apr 24, 2024
b09bcec
fix
pcmoritz Apr 24, 2024
052e2b3
update
pcmoritz Apr 24, 2024
b33c6d7
update
pcmoritz Apr 25, 2024
475f58d
refactor
pcmoritz Apr 25, 2024
56b4880
update
pcmoritz Apr 25, 2024
be37154
revert
pcmoritz Apr 25, 2024
9c54d19
format
pcmoritz Apr 25, 2024
c5155ea
Update vllm/_custom_ops.py
pcmoritz Apr 25, 2024
948cca7
Update vllm/model_executor/layers/fused_moe/fused_moe.py
pcmoritz Apr 25, 2024
3feb887
Update vllm/model_executor/models/mixtral.py
pcmoritz Apr 25, 2024
df16316
format
pcmoritz Apr 25, 2024
7b6b0fa
support static scales
robertgshaw2-neuralmagic Apr 25, 2024
1a3b2e1
fixed example
robertgshaw2-neuralmagic Apr 25, 2024
63ad2ef
Delete quantize.ipynb
robertgshaw2-neuralmagic Apr 25, 2024
794f1a1
Update vllm/_custom_ops.py
pcmoritz Apr 25, 2024
c13b6a4
update
pcmoritz Apr 25, 2024
5a230ed
update
pcmoritz Apr 25, 2024
80069c9
format
pcmoritz Apr 25, 2024
5ce17d0
activation_scale -> act_scale
pcmoritz Apr 25, 2024
5fc0335
Update scheme->activation_scheme
mgoin Apr 25, 2024
92d5162
fix dynamic scaling -- need init to zero due to atomic update
pcmoritz Apr 25, 2024
e1bfe10
Format
mgoin Apr 25, 2024
7242600
Fix tuple type
mgoin Apr 25, 2024
8512513
Merge remote-tracking branch 'pcmoritz/mixtral-fp8-static' into fp8-s…
robertgshaw2-neuralmagic Apr 26, 2024
21ddbb4
stash tyler's state
robertgshaw2-neuralmagic Apr 26, 2024
d27015c
stash
robertgshaw2-neuralmagic Apr 26, 2024
1111f87
cutlass working, but slow jitting on hotpath
robertgshaw2-neuralmagic Apr 26, 2024
f5d32ae
first end to end run with mixtral
robertgshaw2-neuralmagic Apr 26, 2024
924e8ce
added missed file
robertgshaw2-neuralmagic Apr 26, 2024
823a2e7
Update run_fp8.py
mgoin Apr 26, 2024
81f42be
Dynamic FP8 works, but static does not (#213)
robertgshaw2-neuralmagic Apr 27, 2024
1a4fd8a
static correctness
robertgshaw2-neuralmagic Apr 27, 2024
e48c981
static fp8 loading
robertgshaw2-neuralmagic Apr 27, 2024
02f683e
working for dense models
robertgshaw2-neuralmagic Apr 27, 2024
81b73ef
Update weight_utils.py
robertgshaw2-neuralmagic Apr 27, 2024
58dbe0f
moving mixtral updates to separate pr
robertgshaw2-neuralmagic Apr 27, 2024
6068dc5
Merge branch 'main' into fp8-static
robertgshaw2-neuralmagic Apr 27, 2024
a8d4b33
make ./format pass
robertgshaw2-neuralmagic Apr 27, 2024
5be0970
better comments in linear.py
robertgshaw2-neuralmagic Apr 27, 2024
ef7992b
better comments in linear.py
robertgshaw2-neuralmagic Apr 27, 2024
0667791
fixed opt-125
robertgshaw2-neuralmagic Apr 27, 2024
d8adf14
removed run_fp8.py
robertgshaw2-neuralmagic Apr 27, 2024
9bb1a2b
format
robertgshaw2-neuralmagic Apr 27, 2024
169c9ed
Cleanup opt.py
mgoin Apr 27, 2024
8ef9c7d
added testing
robertgshaw2-neuralmagic Apr 27, 2024
c7d6dd6
./format.sh
robertgshaw2-neuralmagic Apr 27, 2024
50b5823
fixed typing
robertgshaw2-neuralmagic Apr 27, 2024
4156ca9
fixed typing
robertgshaw2-neuralmagic Apr 27, 2024
3148fc9
added warning format
robertgshaw2-neuralmagic Apr 27, 2024
7846d67
Update opt.py
robertgshaw2-neuralmagic Apr 27, 2024
ba408c6
formatted
robertgshaw2-neuralmagic Apr 27, 2024
04617fd
Update vllm/model_executor/layers/quantization/fp8.py
robertgshaw2-neuralmagic Apr 27, 2024
cc3d395
Update vllm/model_executor/layers/quantization/fp8.py
robertgshaw2-neuralmagic Apr 27, 2024
f556016
auto detect shared scale (#214)
robertgshaw2-neuralmagic Apr 28, 2024
30bfbd8
./format.sh
robertgshaw2-neuralmagic Apr 28, 2024
572107a
Update vllm/model_executor/layers/quantization/fp8.py
robertgshaw2-neuralmagic Apr 29, 2024
41fbde9
./format.sh
robertgshaw2-neuralmagic Apr 29, 2024
f2cd561
addressed cody's comments + format
robertgshaw2-neuralmagic Apr 29, 2024
125266e
make mypy happy
robertgshaw2-neuralmagic Apr 29, 2024
8a566a7
Merge remote-tracking branch 'upstream/main' into fp8-static
robertgshaw2-neuralmagic Apr 30, 2024
280a4d5
test
robertgshaw2-neuralmagic Apr 30, 2024
8e1ede1
cleaned up
robertgshaw2-neuralmagic Apr 30, 2024
d067428
Update vllm/model_executor/layers/quantization/fp8.py
robertgshaw2-neuralmagic Apr 30, 2024
7 changes: 6 additions & 1 deletion csrc/ops.h
@@ -146,7 +146,12 @@ void gptq_shuffle(
torch::Tensor q_perm,
int bit);

void scaled_fp8_quant(
void static_scaled_fp8_quant(
torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& scale);

void dynamic_scaled_fp8_quant(
torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& scale);
3 changes: 2 additions & 1 deletion csrc/pybind.cpp
@@ -73,7 +73,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
ops.def("scaled_fp8_quant", &scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor");
ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant, "Compute FP8 quantized tensor for given scaling factor");
ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor");
ops.def(
"moe_align_block_size",
&moe_align_block_size,
25 changes: 24 additions & 1 deletion csrc/quantization/fp8/fp8_cuda_kernels.cu
@@ -74,7 +74,30 @@ __global__ void scaled_fp8_quant_kernel(

} // namespace vllm

void scaled_fp8_quant(
void static_scaled_fp8_quant(
torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., d]
torch::Tensor& scale) // [1]
{
int64_t num_tokens = input.numel() / input.size(-1);
int64_t num_elems = input.numel();
dim3 grid(num_tokens);
dim3 block(1024);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(),
"scaled_fp8_quant_kernel",
[&] {
vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<c10::Float8_e4m3fn>(),
input.data_ptr<scalar_t>(),
scale.data_ptr<float>(),
num_elems);
});
}

void dynamic_scaled_fp8_quant(
torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., d]
torch::Tensor& scale) // [1]
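
Both kernels share the same element-wise conversion; they differ only in where the per-tensor scale comes from (loaded from the checkpoint versus reduced from the input on the fly). As a rough reference for the math only, a Python sketch assuming the e4m3 format; the names static_quant/dynamic_quant are illustrative, not vLLM APIs:

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def static_quant(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Static path: scale is a precomputed [1] tensor (e.g. an act_scale
    # serialized in an FP8 checkpoint).
    return (x / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)

def dynamic_quant(x: torch.Tensor):
    # Dynamic path: derive the scale from the tensor itself; the CUDA kernel
    # computes this absmax via an atomic reduction, which is why the scale
    # buffer must be zero-initialized by the caller.
    scale = x.abs().max().float() / FP8_MAX
    return (x / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn), scale
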
1 change: 1 addition & 0 deletions requirements-cuda.txt
@@ -7,3 +7,4 @@ nvidia-ml-py # for pynvml package
vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library
torch == 2.2.1
xformers == 0.0.25 # Requires PyTorch 2.2.1
nvidia-cutlass
41 changes: 41 additions & 0 deletions run_fp8.py
@@ -0,0 +1,41 @@
import argparse

from transformers import AutoTokenizer

from vllm import LLM

choices = ["llama-static", "mistral-static", "mistral-dynamic", "mixtral-static"]

parser = argparse.ArgumentParser()
parser.add_argument("--type", choices=choices)

if __name__ == "__main__":
args = parser.parse_args()

if args.type == "llama-static":
model_name = "nm-testing/Meta-Llama-3-8B-Instruct-FP8"
elif args.type == "mistral-static":
model_name = "nm-testing/mistral-fp8-static"
elif args.type == "mistral-dynamic":
model_name = "nm-testing/mistral-fp8-dynamic"
elif args.type == 'mixtral-static':
model_name = "nm-testing/Mixtral-8x7B-Instruct-v0.1-FP8"
else:
raise ValueError(f"--type should be in {choices}")

model = LLM(model_name,
enforce_eager=True,
max_model_len=1024)

tokenizer = AutoTokenizer.from_pretrained(model_name)

prompt = tokenizer.apply_chat_template([{
"role": "user",
"content": "What is your name"
}], tokenize=False, add_generation_prompt=True)
print(f"----- Prompt: {prompt}")

outputs = model.generate(prompt)
print(outputs)
generation = outputs[0].outputs[0].text
print(f"----- Generation: {generation}")
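
This script is a manual smoke test; presumably it is run as, for example, python run_fp8.py --type mistral-static, with the listed nm-testing checkpoints downloaded from the Hugging Face Hub on first use. It is removed again later in the PR (see the "removed run_fp8.py" commit above).
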
12 changes: 9 additions & 3 deletions vllm/_custom_ops.py
@@ -154,10 +154,16 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,


# fp8
def scaled_fp8_quant(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
def scaled_fp8_quant(
input: torch.Tensor,
scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
vllm_ops.scaled_fp8_quant(output, input, scale)
if scale is None:
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
vllm_ops.dynamic_scaled_fp8_quant(output, input, scale)
else:
vllm_ops.static_scaled_fp8_quant(output, input, scale)
return output, scale


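
With the optional scale argument, the same Python entry point now covers both schemes. A minimal usage sketch, assuming a CUDA build of vLLM with these kernels compiled in:

import torch
from vllm import _custom_ops as ops

x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")

# Dynamic: no scale supplied, so the kernel computes one and returns it.
x_q, x_scale = ops.scaled_fp8_quant(x)

# Static: reuse a precomputed activation scale (e.g. loaded from an FP8
# checkpoint); the tensor is quantized with it and the same scale is
# returned unchanged.
x_q_static, same_scale = ops.scaled_fp8_quant(x, x_scale)
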
5 changes: 5 additions & 0 deletions vllm/model_executor/layers/fused_gemm_dq/__init__.py
@@ -0,0 +1,5 @@
from vllm.model_executor.layers.fused_gemm_dq.fused_gemm_dq_fp8 import fused_gemm_dq_fp8

__all__ = [
"fused_gemm_dq_fp8",
]
87 changes: 87 additions & 0 deletions vllm/model_executor/layers/fused_gemm_dq/fused_gemm_dq_fp8.py
@@ -0,0 +1,87 @@
import cutlass
from cutlass import Tensor as FakeTensor
import cutlass.epilogue

import torch
from typing import Optional, Tuple, Dict


def setup_dequant_epilogue(
plan: cutlass.op.Gemm,
dq: torch.Tensor,
scale_a: Optional[torch.Tensor],
scale_b: Optional[torch.Tensor],
bias: Optional[torch.Tensor],
) -> Tuple[cutlass.op.Gemm, Dict]:
assert bias is None

if all([scale_a is None, scale_b is None]):
return plan, None
assert scale_b is not None

def epilog_with_scale_b(accum, scale_b):
D = scale_b * accum
return D

def epilog_with_both_scales(accum, scale_a, scale_b):
D = scale_a * (scale_b * accum)
return D

visitor_args = {"scale_a": scale_a, "scale_b": scale_b, "D": dq}
epilogue_tensors = {
"accum": FakeTensor(
element=torch.float32,
shape=dq.shape,
layout_tag=cutlass.LayoutType.RowMajor,
),
"D": dq,
"scale_b": scale_b,
}
epilog_fn = epilog_with_scale_b

if scale_a is not None:
epilogue_tensors["scale_a"] = scale_a
visitor_args["scale_a"] = scale_a
epilog_fn = epilog_with_both_scales

plan.epilogue_visitor = cutlass.epilogue.trace(epilog_fn, epilogue_tensors)
return plan, visitor_args


def fused_gemm_dq_fp8(
x_q: torch.Tensor,
w_q: torch.Tensor,
out_dtype: torch.dtype,
scale_a: Optional[torch.Tensor] = None,
scale_b: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
dq = torch.empty((x_q.shape[0], w_q.shape[1]), dtype=out_dtype, device="cuda")
C = torch.zeros((x_q.shape[0], w_q.shape[1]), dtype=out_dtype, device="cuda")

plan = cutlass.op.Gemm(
element_A=x_q.dtype,
element_B=w_q.dtype,
element_C=dq.dtype,
element_D=dq.dtype,
layout_A=cutlass.LayoutType.RowMajor,
layout_B=cutlass.LayoutType.ColumnMajor,
layout_C=cutlass.LayoutType.RowMajor,
element_accumulator=torch.float32,
kernel_cc=90,
)

plan, visitor_args = setup_dequant_epilogue(plan, dq, scale_a, scale_b, bias)

plan.run(
x_q,
w_q,
C,
dq,
alpha=1,
beta=0,
visitor_args=visitor_args,
print_module=False,
)

return dq
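
A hedged sketch of how this CUTLASS-backed helper could be driven from the quantization op above. The weight layout (a transposed [out_features, in_features] matrix to satisfy the column-major B operand) and the per-tensor scale shapes are assumptions read off the signature, not code from the PR; it also needs the nvidia-cutlass package and an SM90 GPU, since the plan is built with kernel_cc=90:

import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_gemm_dq import fused_gemm_dq_fp8

x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
w = torch.randn(11008, 4096, dtype=torch.float16, device="cuda")  # [out, in]

# Per-tensor quantization of activations (dynamic) and weights.
x_q, x_scale = ops.scaled_fp8_quant(x)
w_q, w_scale = ops.scaled_fp8_quant(w)

# FP8 GEMM with dequantization folded into the CUTLASS epilogue:
# out ~= x_scale * (w_scale * (x_q @ w_q.t())).
out = fused_gemm_dq_fp8(x_q, w_q.t(), out_dtype=torch.float16,
                        scale_a=x_scale, scale_b=w_scale)
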
13 changes: 9 additions & 4 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -220,8 +220,9 @@ def moe_align_block_size(


def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
B_scale: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
A_scale: Optional[torch.Tensor],
B_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
@@ -232,10 +233,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
assert sorted_token_ids.stride(0) == 1

if not use_fp8:
A_scale = None
assert A_scale is None
assert B_scale is None
else:
A, A_scale = ops.scaled_fp8_quant(A)
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
assert B_scale is not None

grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
@@ -318,6 +319,8 @@ def fused_moe(
use_fp8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -434,6 +437,7 @@ def fused_moe(
invoke_fused_moe_kernel(hidden_states,
w1,
intermediate_cache1,
a1_scale,
w1_scale,
topk_weights,
topk_ids,
@@ -451,6 +455,7 @@
invoke_fused_moe_kernel(intermediate_cache2,
w2,
intermediate_cache3,
a2_scale,
w2_scale,
topk_weights,
topk_ids,
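
At the fused_moe level, the new a1_scale/a2_scale arguments carry static activation scales from an FP8 checkpoint; leaving them as None with use_fp8=True falls back to dynamic quantization inside invoke_fused_moe_kernel, as shown in the hunk above. A call sketch with toy shapes (the per-expert scale shapes are an assumption, and an FP8-capable GPU is required):

import torch
from vllm.model_executor.layers.fused_moe import fused_moe

E, K, N, topk = 4, 512, 1024, 2  # experts, hidden size, intermediate size
hidden_states = torch.randn(8, K, dtype=torch.float16, device="cuda")
gating_output = torch.randn(8, E, dtype=torch.float16, device="cuda")

# FP8 expert weights with their per-expert weight scales.
w1 = torch.zeros(E, 2 * N, K, dtype=torch.float8_e4m3fn, device="cuda")
w2 = torch.zeros(E, K, N, dtype=torch.float8_e4m3fn, device="cuda")
w1_scale = torch.ones(E, dtype=torch.float32, device="cuda")
w2_scale = torch.ones(E, dtype=torch.float32, device="cuda")

out = fused_moe(hidden_states, w1, w2, gating_output, topk,
                renormalize=True, use_fp8=True,
                w1_scale=w1_scale, w2_scale=w2_scale,
                a1_scale=None, a2_scale=None)  # None => dynamic activation scales
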
48 changes: 48 additions & 0 deletions vllm/model_executor/layers/linear.py
@@ -295,6 +295,21 @@ def weight_loader(self,
param_data = param.data
output_dim = getattr(param, "output_dim", None)
is_metadata = getattr(param, "is_metadata", False)

# TODO: document.
# TODO: sync with is_metadata.
# For loading scales.
shard_indexer = getattr(param, "shard_indexer", None)
logical_widths = getattr(param, "logical_widths", None)
if output_dim is not None and shard_indexer is not None:
raise NotImplementedError(
"We do not currently support output_dim != None and "
"shard_indexer != None for a parameter. Please open an issue.")
if loaded_shard_id is None and shard_indexer is not None:
raise NotImplementedError(
"We do not currently support loaded_shard_id == None and "
"shard_indexer != None for a parameter. Please open an issue.")

if loaded_shard_id is None:
# Loaded weight is already packed.
if output_dim is None:
@@ -352,6 +367,15 @@ def weight_loader(self,
shard_size = loaded_weight.shape[0]
shard_offset = loaded_shard_id * shard_size
param_data = param_data.narrow(0, shard_offset, shard_size)

# TODO: sync with is_metadata UX.
# If a param_shard_splitter is defined by the LinearMethod, use it.
elif shard_indexer is not None:
param_data, loaded_weight = shard_indexer(param_data,
loaded_weight,
loaded_shard_id,
logical_widths)

else:
ignore_warning = getattr(param, "ignore_warning", False)
if not ignore_warning:
@@ -434,6 +458,18 @@ def weight_loader(self,
output_dim = getattr(param, "output_dim", None)
is_metadata = getattr(param, "is_metadata", False)

# TODO: sync with is_metadata UX
shard_indexer = getattr(param, "shard_indexer", None)
logical_widths = getattr(param, "logical_widths", None)
if output_dim is not None and shard_indexer is not None:
raise NotImplementedError(
"We do not currently support output_dim != None and "
"shard_indexer != None for a parameter. Please open an issue.")
if loaded_shard_id is None and shard_indexer is not None:
raise NotImplementedError(
"We do not currently support loaded_shard_id == None and "
"shard_indexer != None for a parameter. Please open an issue.")

if loaded_shard_id is None:
# Loaded weight is already packed.
if output_dim is None:
@@ -506,6 +542,13 @@ def weight_loader(self,
shard_index = ["q", "k", "v"].index(loaded_shard_id)
param_data = param_data.narrow(0, shard_index * shard_size,
shard_size)
# TODO: sync with QKV
# If a param_shard_splitter is defined by the LinearMethod, use it.
elif shard_indexer is not None:
param_data, loaded_weight = shard_indexer(param_data,
loaded_weight,
loaded_shard_id,
logical_widths)
else:
ignore_warning = getattr(param, "ignore_warning", False)
if not ignore_warning:
@@ -602,6 +645,11 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(input_dim, start_idx,
shard_size)
# TODO: canon
# This is for loading scales for fp8, which have no dims.
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)

assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)

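
The shard_indexer/logical_widths attributes exist so that scalar scales belonging to one logical weight (q/k/v, or gate/up) can be routed into the right slot of a merged parameter during loading. A hypothetical indexer matching the call sites above (the exact contract is not spelled out in this diff, so treat this as an illustration only):

import torch
from typing import List, Tuple, Union

def scale_shard_indexer(
    param_data: torch.Tensor,       # merged scale param, e.g. shape [3] for q, k, v
    loaded_weight: torch.Tensor,    # scalar scale from the checkpoint, shape [] or [1]
    shard_id: Union[int, str],      # int for gate/up shards, "q"/"k"/"v" for attention
    logical_widths: List[int],      # widths of the logical shards (unused for scalars)
) -> Tuple[torch.Tensor, torch.Tensor]:
    if isinstance(shard_id, str):
        shard_id = ["q", "k", "v"].index(shard_id)
    # Narrow the merged parameter to the single slot this scale fills.
    return param_data[shard_id], loaded_weight.reshape(-1)[0]
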
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/quantization/__init__.py
@@ -4,7 +4,8 @@
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.fp8 import FP8Config
# from vllm.model_executor.layers.quantization.fp8 import FP8Config
from vllm.model_executor.layers.quantization.fp8_serialized import FP8Config
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
6 changes: 6 additions & 0 deletions vllm/model_executor/layers/quantization/base_config.py
@@ -29,6 +29,12 @@ def get_min_capability(self) -> int:
"""
raise NotImplementedError

# The following is not an abstract method and returns True by default.
@classmethod
def require_config_file(cls) -> bool:
"""Whether this quantization config needs a configuration filen."""
return True

@staticmethod
@abstractmethod
def get_config_filenames() -> List[str]:
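
require_config_file gives schemes that do not ship serialized quantization metadata (for example, quantizing an unquantized checkpoint to FP8 on the fly) a way to opt out of the config-file lookup. A minimal hypothetical override, shown only to illustrate the hook:

from vllm.model_executor.layers.quantization.base_config import QuantizationConfig

class OnTheFlyFp8Config(QuantizationConfig):
    """Hypothetical config: weights are quantized at load time, so there is
    no serialized quantization config to look for."""

    @classmethod
    def require_config_file(cls) -> bool:
        return False

    # The remaining abstract methods (get_name, get_config_filenames,
    # from_config, and so on) are omitted here for brevity.
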