Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Kernel] Optimize FP8 support for MoE kernel / Mixtral via static scales #4343

Merged
merged 33 commits into from
Apr 27, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,12 @@ void gptq_shuffle(
torch::Tensor q_perm,
int bit);

void scaled_fp8_quant(
void static_scaled_fp8_quant(
torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& scale);

void dynamic_scaled_fp8_quant(
torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& scale);
Expand Down
3 changes: 2 additions & 1 deletion csrc/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
ops.def("scaled_fp8_quant", &scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor");
ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant, "Compute FP8 quantized tensor for given scaling factor");
ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor");
ops.def(
"moe_align_block_size",
&moe_align_block_size,
Expand Down
25 changes: 24 additions & 1 deletion csrc/quantization/fp8/fp8_cuda_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,30 @@ __global__ void scaled_fp8_quant_kernel(

} // namespace vllm

void scaled_fp8_quant(
void static_scaled_fp8_quant(
torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., d]
torch::Tensor& scale) // [1]
{
int64_t num_tokens = input.numel() / input.size(-1);
int64_t num_elems = input.numel();
dim3 grid(num_tokens);
dim3 block(1024);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(),
"scaled_fp8_quant_kernel",
[&] {
vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<c10::Float8_e4m3fn>(),
input.data_ptr<scalar_t>(),
scale.data_ptr<float>(),
num_elems);
});
}

void dynamic_scaled_fp8_quant(
torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., d]
torch::Tensor& scale) // [1]
Expand Down
11 changes: 8 additions & 3 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,15 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,


# fp8
def scaled_fp8_quant(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
def static_scaled_fp8_quant(input: torch.Tensor, scale: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
pcmoritz marked this conversation as resolved.
Show resolved Hide resolved
output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
vllm_ops.static_scaled_fp8_quant(output, input, scale)
return output

def dynamic_scaled_fp8_quant(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
vllm_ops.scaled_fp8_quant(output, input, scale)
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
vllm_ops.dynamic_scaled_fp8_quant(output, input, scale)
return output, scale


Expand Down
33 changes: 21 additions & 12 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
from vllm.utils import is_hip

logger = init_logger(__name__)
Expand Down Expand Up @@ -220,22 +222,25 @@ def moe_align_block_size(


def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
B_scale: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
A_scale: Optional[torch.Tensor], B_scale: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool, top_k: int,
config: Dict[str, Any], compute_type: tl.dtype,
use_fp8: bool) -> None:
linear_method: Optional[LinearMethodBase]) -> None:
assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1

if not use_fp8:
A_scale = None
if not isinstance(linear_method, Fp8LinearMethod):
assert A_scale is None
assert B_scale is None
else:
A, A_scale = ops.scaled_fp8_quant(A)
elif linear_method.quant_config.act_scaling == "static":
A = ops.static_scaled_fp8_quant(A, A_scale)
assert B_scale is not None
elif linear_method.quant_config.act_scaling == "dynamic":
A, A_scale = ops.dynamic_scaled_fp8_quant(A)
assert B_scale is not None

grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
Expand Down Expand Up @@ -265,7 +270,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
MUL_ROUTED_WEIGHT=mul_routed_weight,
top_k=top_k,
compute_type=compute_type,
use_fp8=use_fp8,
use_fp8=isinstance(linear_method, Fp8LinearMethod),
**config,
)

Expand Down Expand Up @@ -315,9 +320,11 @@ def fused_moe(
renormalize: bool,
inplace: bool = False,
override_config: Optional[Dict[str, Any]] = None,
use_fp8: bool = False,
linear_method: Optional[LinearMethodBase] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
Expand Down Expand Up @@ -395,7 +402,7 @@ def fused_moe(
else:
# First try to load optimal config from the file
configs = get_moe_configs(E, w2.shape[2],
"float8" if use_fp8 else None)
"float8" if isinstance(linear_method, Fp8LinearMethod) else None)

if configs:
# If an optimal configuration map has been found, look up the
Expand Down Expand Up @@ -434,6 +441,7 @@ def fused_moe(
invoke_fused_moe_kernel(hidden_states,
w1,
intermediate_cache1,
a1_scale,
w1_scale,
topk_weights,
topk_ids,
Expand All @@ -444,13 +452,14 @@ def fused_moe(
topk_ids.shape[1],
config,
compute_type=tl.float16,
use_fp8=use_fp8)
linear_method=linear_method)

ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))

invoke_fused_moe_kernel(intermediate_cache2,
w2,
intermediate_cache3,
a2_scale,
w2_scale,
topk_weights,
topk_ids,
Expand All @@ -461,7 +470,7 @@ def fused_moe(
1,
config,
compute_type=tl.float16,
use_fp8=use_fp8)
linear_method=linear_method)

if inplace:
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
Expand Down
13 changes: 11 additions & 2 deletions vllm/model_executor/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@
class FP8Config(QuantizationConfig):
"""Config class for FP8."""

config_file_optional = True
pcmoritz marked this conversation as resolved.
Show resolved Hide resolved

def __init__(
self,
act_scaling: str="dynamic",
) -> None:
self.act_scaling = act_scaling

@classmethod
def get_name(cls) -> str:
return "fp8"
Expand All @@ -30,11 +38,12 @@ def get_min_capability(cls) -> int:

@classmethod
def get_config_filenames(cls) -> List[str]:
return []
return ["quantize_config.json"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just wondering: For static quantization, can we just have this scaling factor in config.json? I'm not sure if this is a better decision than having a separate quantization config file, but it seems feasible. WDYT?

Copy link
Collaborator Author

@pcmoritz pcmoritz Apr 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it is feasible. The reason why I put it into the safetensor files is because that's a better place to store tensors than the config.json and for some schemes, these tensors can be larger (e.g. for per-channel quantization). Since the checkpoint already needs to be rewritten (to add the quant_config.json and also possibly to convert weights to FP8), I don't think it is a big problem to rewrite the safetensors to include activation scales. This is also what https://huggingface.co/FriendliAI/Mistral-7B-Instruct-v0.2-fp8 does (note that their model format is otherwise not very useful since it stores the weights as int8).

I assume that a standard for FP8 will emerge, and I would expect it to store the scales in the safetensor files -- this is a no brainer for weight scales to keep them close to the weights and make sure they are consistent with the quantized weights but also makes sense for activation scales. Once we have a standard, we should use that. Right now, I don't think trying to invent our own standard in quantize_config.json is a good idea (since it involves a schema), whereas storing scales in the safetensor scales is pretty canonical and doesn't require us to invent a lot of convention.

These are my reasons -- let me know if you prefer otherwise :)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also cc @robertgshaw2-neuralmagic who also thought about this a bunch I think :)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with @pcmoritz , the scales should be in a safetensors file

The mental model for the quantize_config.json is that it would hold metadata about what is in the safetensors file.

So examples could be:

  • Datatype of the weights
  • Whether the activations are static or dynamic (so we dont have to peek into safetensors)
  • Channelwise vs not, etc

For this first implementation, we don't this, but if we start supporting various different schemes, then we will (since we need to know this when create_weights is called - which happens before we look at the safetensors file)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, mostly quantize_config is not used anymore, even for AutoGPTQ

The config.json typically has a quantization_config.

https://huggingface.co/astronomer/Llama-3-8B-Instruct-GPTQ-8-Bit/blob/main/config.json

We only fall back to quantize_config.json if quantization_config is not found in the config.json


@classmethod
def from_config(cls, config: Dict[str, Any]) -> "FP8Config":
return cls()
act_scaling = cls.get_from_keys(config, ["act_scaling"])
return cls(act_scaling)

def get_linear_method(self) -> "Fp8LinearMethod":
return Fp8LinearMethod(self)
Expand Down
11 changes: 5 additions & 6 deletions vllm/model_executor/model_loader/weight_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,18 +135,17 @@ def get_quant_config(model_config: ModelConfig,
else:
hf_folder = model_name_or_path

possible_config_filenames = quant_cls.get_config_filenames()

# If the quantization config is not found, use the default config.
if not possible_config_filenames:
return quant_cls()

config_files = glob.glob(os.path.join(hf_folder, "*.json"))
possible_config_filenames = quant_cls.get_config_filenames()

quant_config_files = [
f for f in config_files if any(
f.endswith(x) for x in possible_config_filenames)
]

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for GPTQ/AWQ, having a quantize_config.json is not necessarily required. So I think this check could break models that have:

  • quantization_config in config.json
  • no quantize_config.json

For example: https://huggingface.co/casperhansen/llama-3-70b-instruct-awq/tree/main

Our CI does not seem to have any models with this setup

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I fixed this now by removing the quantize_config.json support -- we can just use config.json for specifying the quantization for FP8 checkpoints for the time being :)

# If the quantization config is optional and not provided, use the default config.
if getattr(quant_cls, "config_file_optional", False) and not quant_config_files:
return quant_cls()
if len(quant_config_files) == 0:
raise ValueError(
f"Cannot find the config file for {model_config.quantization}")
Expand Down
60 changes: 45 additions & 15 deletions vllm/model_executor/models/mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,7 @@ def __init__(
self.top_k = top_k
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size // self.tp_size
# FIXME(pcmoritz): Make this more general to support different
# quantization schemes
self.use_fp8 = isinstance(linear_method, Fp8LinearMethod)
self.linear_method = linear_method

if params_dtype is None:
params_dtype = torch.get_default_dtype()
Expand All @@ -104,22 +102,41 @@ def __init__(
device="cuda",
dtype=self.params_dtype))

set_weight_attrs(self.ws, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(self.w2s, {
"weight_loader": self.weight_loader,
})

use_fp8 = isinstance(linear_method, Fp8LinearMethod)

# Scaling factors for FP8 weights
self.ws_scale = nn.Parameter(
torch.ones(
self.num_total_experts, device="cuda", dtype=torch.float32),
requires_grad=False) if self.use_fp8 else None
requires_grad=False) if use_fp8 else None
self.w2s_scale = nn.Parameter(
torch.ones(
self.num_total_experts, device="cuda", dtype=torch.float32),
requires_grad=False) if self.use_fp8 else None

set_weight_attrs(self.ws, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(self.w2s, {
"weight_loader": self.weight_loader,
})
requires_grad=False) if use_fp8 else None

# Scaling factors for FP8 activations
need_act_scales = use_fp8 and linear_method.quant_config.act_scaling == "static"
self.as_scale = nn.Parameter(
torch.zeros(1, device="cuda", dtype=torch.float32),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Theoretically, we shouldn't use "cuda" in the model code. Since the GPU worker sets "cuda" as the default device in torch, device="cuda" is not necessary. Also, it's not good for the compatibility with non-CUDA devices.

This rule is violated for Mixtral and other MoE models unfortunately. 😢

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I'll make a follow up PR to remove the device="cuda" -- since we also specify it explicitly for the other parameters, I don't want to be inconsistent for this PR :)

requires_grad=False) if need_act_scales else None
self.a2s_scale = nn.Parameter(
torch.zeros(1, device="cuda", dtype=torch.float32),
requires_grad=False) if need_act_scales else None

if need_act_scales:
set_weight_attrs(self.as_scale, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(self.a2s_scale, {
"weight_loader": self.weight_loader,
})

def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated to this PR, I think we should have an MoELayer that is shared across models

All of these changes are currently only impacting Mixtral, but could also be applied to other models. Since we have all this generic logic in the model definitions, we are losing out at running others with these features

weight_name: str, expert_id: int):
Expand All @@ -134,9 +151,12 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
shard_size:2 * shard_size, :] = loaded_weight[shard, :]
if weight_name.endswith("w2.weight"):
param_data[expert_id, :, :] = loaded_weight[:, shard]
if "activation_scale" in weight_name:
param_data[:] = param_data[:].max(loaded_weight)
print("loaded scale", weight_name, param_data)
pcmoritz marked this conversation as resolved.
Show resolved Hide resolved

def process_weights_after_loading(self):
if self.use_fp8:
if isinstance(self.linear_method, Fp8LinearMethod):
ws = torch.empty_like(self.ws.data, dtype=torch.float8_e4m3fn)
w2s = torch.empty_like(self.w2s.data, dtype=torch.float8_e4m3fn)
for expert in range(self.num_total_experts):
Expand All @@ -159,9 +179,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
self.top_k,
renormalize=True,
inplace=True,
use_fp8=self.use_fp8,
linear_method=self.linear_method,
w1_scale=self.ws_scale,
w2_scale=self.w2s_scale)
w2_scale=self.w2s_scale,
a1_scale=self.as_scale,
a2_scale=self.a2s_scale)

if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
Expand Down Expand Up @@ -443,11 +465,19 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
]

expert_params_mapping = [
# These are the weights for the experts
# (param_name, weight_name, expert_id)
("ws" if weight_name in ["w1", "w3"] else "w2s",
f"experts.{expert_id}.{weight_name}.weight", expert_id)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
] + [
# These are the activation scales for the experts
# (param_name, weight_name, expert_id)
("a_scale" if activation_name in ["a1", "a3"] else "a2_scale",
f"experts.{expert_id}.{activation_name}.activation_scale", expert_id)
for expert_id in range(self.config.num_local_experts)
for activation_name in ["a1", "a2", "a3"]
]

params_dict = dict(self.named_parameters())
Expand Down
Loading