From 99bb48dca87ae65481a225a5f184245823d2e7a9 Mon Sep 17 00:00:00 2001 From: Qiang Li Date: Thu, 24 Oct 2024 15:53:41 -0700 Subject: [PATCH 1/4] Add fp8 support for Llama model family on Navi4x --- CMakeLists.txt | 16 +++++++++- cmake/utils.cmake | 30 +++++++++++++++++++ csrc/quantization/fp8/common.cu | 6 ++-- examples/offline_inference.py | 2 +- vllm/_custom_ops.py | 4 +-- .../model_executor/layers/quantization/fp8.py | 14 ++++----- vllm/model_executor/models/llama.py | 14 +++++---- vllm/utils.py | 5 ++++ 8 files changed, 72 insertions(+), 19 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 5b13e9d08c7f6..f3fd608894e6d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -37,7 +37,7 @@ set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11" "3.12") set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0") # Supported AMD GPU architectures. -set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101") +set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1200") # # Supported/expected torch versions for CUDA/ROCm. @@ -172,6 +172,20 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-result") # get_torch_gpu_compiler_flags(VLLM_GPU_FLAGS ${VLLM_GPU_LANG}) +# +# Get supported FP8 format based on GPU arches +# +get_supported_fp8_format(FP8_FORMAT ${VLLM_GPU_LANG} "${VLLM_GPU_ARCHES}") +if(${FP8_FORMAT} STREQUAL "E4M3FN") + message(STATUS "FP8 format: E4M3FN") + list(APPEND VLLM_GPU_FLAGS "-DUSE_CUDA_FP8_FORMAT") +elseif(${FP8_FORMAT} STREQUAL "E4M3FNUZ") + message(STATUS "FP8 format: E4M3FNUZ") + list(APPEND VLLM_GPU_FLAGS "-DUSE_HIP_FP8_FORMAT") +elseif(${FP8_FORMAT} STREQUAL "CONFLICT") + message(FATAL_ERROR "Target architectures support different types of FP8 formats!") +endif() + # # Set nvcc parallelism. # diff --git a/cmake/utils.cmake b/cmake/utils.cmake index 24bb7299338ac..3866ba58a6e11 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -435,3 +435,33 @@ function (define_gpu_extension_target GPU_MOD_NAME) install(TARGETS ${GPU_MOD_NAME} LIBRARY DESTINATION ${GPU_DESTINATION} COMPONENT ${GPU_MOD_NAME}) endfunction() + + +# gfx12xx should not be compiled together with gfx94x (MI300) because they support different types of FP8 format. 
+# FP8_FORMAT will be returned (E4M3FN / E4M3FNUZ / NONE / CONFLICT) +macro (get_supported_fp8_format FP8_FORMAT GPU_LANG GPU_ARCHES) + set(_USING_CUDA_FP8_FORMAT "FALSE") + set(_USING_HIP_FP8_FORMAT "FALSE") + + if (NOT (${GPU_LANG} STREQUAL "HIP")) + set(_USING_CUDA_FP8_FORMAT "TRUE") + else() + foreach (_ARCH ${GPU_ARCHES}) + if (_ARCH MATCHES "gfx94.") + set(_USING_HIP_FP8_FORMAT "TRUE") + elseif(_ARCH MATCHES "gfx12..") + set(_USING_CUDA_FP8_FORMAT "TRUE") + endif() + endforeach() + endif() + + if ((${_USING_CUDA_FP8_FORMAT} STREQUAL "FALSE") AND (${_USING_HIP_FP8_FORMAT} STREQUAL "FALSE")) + set(FP8_FORMAT "NONE") + elseif((${_USING_CUDA_FP8_FORMAT} STREQUAL "FALSE") AND (${_USING_HIP_FP8_FORMAT} STREQUAL "TRUE")) + set(FP8_FORMAT "E4M3FNUZ") + elseif((${_USING_CUDA_FP8_FORMAT} STREQUAL "TRUE") AND (${_USING_HIP_FP8_FORMAT} STREQUAL "FALSE")) + set(FP8_FORMAT "E4M3FN") + else() + set(FP8_FORMAT "CONFLICT") + endif() +endmacro() diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu index f2c609c1b68c3..c05ec89f03cc8 100644 --- a/csrc/quantization/fp8/common.cu +++ b/csrc/quantization/fp8/common.cu @@ -7,7 +7,7 @@ #include "cuda_compat.h" #include "dispatch_utils.h" -#ifndef USE_ROCM +#if defined(USE_CUDA_FP8_FORMAT) #include #include #else @@ -15,7 +15,7 @@ #include #endif -#ifndef USE_ROCM +#if defined(USE_CUDA_FP8_FORMAT) using FP8_TYPE = c10::Float8_e4m3fn; C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits::max(); @@ -50,7 +50,7 @@ __device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val, } float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX)); -#ifndef USE_ROCM +#if defined(USE_CUDA_FP8_FORMAT) return static_cast(r); #else // Use hardware cvt instruction for fp8 on rocm diff --git a/examples/offline_inference.py b/examples/offline_inference.py index 9b758fa2479f6..b8cd6f6a11398 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -11,7 +11,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM. -llm = LLM(model="facebook/opt-125m") +llm = LLM(model="/models/opt-125m") # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. 
outputs = llm.generate(prompts, sampling_params) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 12b9d97091274..49ff0332a6981 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -9,7 +9,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.scalar_type import ScalarType -from vllm.utils import is_hip +from vllm.utils import is_hip, is_navi4x logger = init_logger(__name__) @@ -711,7 +711,7 @@ def scaled_fp8_quant( assert (input.ndim == 2) shape: Union[Tuple[int, int], torch.Size] = input.shape # For rocm, the output fp8 dtype is torch.float_e3m3fnuz - out_dtype: torch.dtype = torch.float8_e4m3fnuz if is_hip() \ + out_dtype: torch.dtype = torch.float8_e4m3fnuz if is_hip() and not is_navi4x() \ else torch.float8_e4m3fn if num_token_padding: shape = (max(num_token_padding, input.shape[0]), shape[1]) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 633cd5f49fc6a..49440ee32c110 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -27,7 +27,7 @@ PerTensorScaleParameter) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform -from vllm.utils import is_hip, print_warning_once +from vllm.utils import is_hip, print_warning_once, is_navi4x ACTIVATION_SCHEMES = ["static", "dynamic"] @@ -227,8 +227,8 @@ def process_weights_after_loading(self, layer: Module) -> None: weight = layer.weight weight_scale = layer.weight_scale - # If rocm, use float8_e4m3fnuz. - if is_hip(): + # If rocm (except Navi4x), use float8_e4m3fnuz. + if is_hip() and not is_navi4x(): weight, weight_scale, input_scale = \ normalize_e4m3fn_to_e4m3fnuz( weight=weight, @@ -378,9 +378,9 @@ def process_weights_after_loading(self, layer: Module) -> None: # If checkpoint is fp16, quantize in place. 
if not self.quant_config.is_checkpoint_fp8_serialized: - # If rocm, use float8_e4m3fnuz as dtype + # If rocm (except Navi4x), use float8_e4m3fnuz as dtype fp8_dtype = torch.float8_e4m3fnuz \ - if is_hip() else torch.float8_e4m3fn + if is_hip() and not is_navi4x() else torch.float8_e4m3fn w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype) w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype) @@ -427,8 +427,8 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.w13_input_scale.max(), requires_grad=False) layer.w2_input_scale = torch.nn.Parameter( layer.w2_input_scale.max(), requires_grad=False) - # If rocm, normalize the weights and scales to e4m3fnuz - if is_hip(): + # If rocm (except Navi4x), normalize the weights and scales to e4m3fnuz + if is_hip() and not is_navi4x(): # Normalize the weights and scales w13_weight, w13_weight_scale, w13_input_scale = \ normalize_e4m3fn_to_e4m3fnuz( diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 447ea2b9348e1..79d60618b0ef2 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -54,7 +54,7 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, PoolerOutput -from vllm.utils import is_hip +from vllm.utils import is_hip, is_navi4x from .interfaces import SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, @@ -87,14 +87,15 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.down_proj", ) - self.use_fp8 = isinstance(quant_config, Fp8Config) + self.use_fp8 = isinstance(quant_config, Fp8Config) if is_hip() and not is_navi4x() else False if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") self.act_fn = SiluAndMul() def forward(self, x): - if is_hip() and x.shape[0] == 1 and x.shape[1] == 1: + # Navi4x is an exception among HIP devices -- it uses the same FP8 format with CUDA devices + if is_hip() and not is_navi4x() and x.shape[0] == 1 and x.shape[1] == 1: out = torch.empty(x.shape[0], self.gate_up_proj.weight.shape[0] // 2, dtype=x.dtype, @@ -189,8 +190,10 @@ def __init__( cache_config=cache_config, quant_config=quant_config, ) + # For CUDA devices and Navi4x, attn_fp8_out will be set to false. self.attn_fp8_out = envs.VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT \ and is_hip() \ + and not is_navi4x() \ and isinstance(quant_config, Fp8Config) def forward( @@ -225,7 +228,7 @@ def __init__( ) -> None: super().__init__() self.hidden_size = config.hidden_size - self.use_fp8 = isinstance(quant_config, Fp8Config) + self.use_fp8 = isinstance(quant_config, Fp8Config) if is_hip() and not is_navi4x() else False rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is not None and getattr( @@ -456,7 +459,8 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None: if not isinstance(self.layers[layer_idx], nn.Identity): layer_self_attn = self.layers[layer_idx].self_attn - if is_hip(): + # Navi4x quantization should be treated as CUDA devices. 
+ if is_hip() and not is_navi4x(): # The scaling factor convention we are assuming is # quantized_value * scaling_factor ~= true_value # which is consistent with the practice of setting diff --git a/vllm/utils.py b/vllm/utils.py index 788e0d424ed52..e61ba39ac6a77 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -7,6 +7,7 @@ import inspect import ipaddress import os +import re import random import socket import subprocess @@ -425,6 +426,10 @@ def is_hip() -> bool: return torch.version.hip is not None +def is_navi4x() -> bool: + return re.match("gfx12..", os.environ.get("PYTORCH_ROCM_ARCH", "")) is not None + + @lru_cache(maxsize=None) def is_cpu() -> bool: from importlib.metadata import PackageNotFoundError, version From 68b9476dc4b69754db7c7a994d1a2a116316283e Mon Sep 17 00:00:00 2001 From: Qiang Li Date: Thu, 24 Oct 2024 17:16:36 -0700 Subject: [PATCH 2/4] Fix a typo in example --- examples/offline_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/offline_inference.py b/examples/offline_inference.py index b8cd6f6a11398..9b758fa2479f6 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -11,7 +11,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM. -llm = LLM(model="/models/opt-125m") +llm = LLM(model="facebook/opt-125m") # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) From eb931f4f1eb360fdad1d95b98a37329e3430df10 Mon Sep 17 00:00:00 2001 From: Qiang Li Date: Fri, 25 Oct 2024 12:42:16 -0700 Subject: [PATCH 3/4] [misc] 1. format updates (split long lines); 2. change implementation of is_navi4x ( from env variable to cuda query) --- vllm/model_executor/layers/quantization/fp8.py | 5 +++-- vllm/model_executor/models/llama.py | 11 +++++++---- vllm/utils.py | 11 +++++++++-- 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 49440ee32c110..225b8cdb82c76 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -27,7 +27,7 @@ PerTensorScaleParameter) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform -from vllm.utils import is_hip, print_warning_once, is_navi4x +from vllm.utils import is_hip, is_navi4x, print_warning_once ACTIVATION_SCHEMES = ["static", "dynamic"] @@ -427,7 +427,8 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.w13_input_scale.max(), requires_grad=False) layer.w2_input_scale = torch.nn.Parameter( layer.w2_input_scale.max(), requires_grad=False) - # If rocm (except Navi4x), normalize the weights and scales to e4m3fnuz + # If rocm (except Navi4x, which uses e4m3fn), + # normalize the weights and scales to e4m3fnuz if is_hip() and not is_navi4x(): # Normalize the weights and scales w13_weight, w13_weight_scale, w13_input_scale = \ diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 79d60618b0ef2..9ed9eabbd63b5 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -87,15 +87,17 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.down_proj", ) - self.use_fp8 = isinstance(quant_config, Fp8Config) if is_hip() and not is_navi4x() else False + self.use_fp8 = isinstance(quant_config, Fp8Config) \ + if is_hip() and not 
is_navi4x() else False if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") self.act_fn = SiluAndMul() def forward(self, x): - # Navi4x is an exception among HIP devices -- it uses the same FP8 format with CUDA devices - if is_hip() and not is_navi4x() and x.shape[0] == 1 and x.shape[1] == 1: + # Navi4x is diff from other HIP devices by using e4m3fn fp8 format + if is_hip() and not is_navi4x() \ + and x.shape[0] == 1 and x.shape[1] == 1: out = torch.empty(x.shape[0], self.gate_up_proj.weight.shape[0] // 2, dtype=x.dtype, @@ -228,7 +230,8 @@ def __init__( ) -> None: super().__init__() self.hidden_size = config.hidden_size - self.use_fp8 = isinstance(quant_config, Fp8Config) if is_hip() and not is_navi4x() else False + self.use_fp8 = isinstance(quant_config, Fp8Config) \ + if is_hip() and not is_navi4x() else False rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is not None and getattr( diff --git a/vllm/utils.py b/vllm/utils.py index e61ba39ac6a77..7c99818eb01d2 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -7,8 +7,8 @@ import inspect import ipaddress import os -import re import random +import re import socket import subprocess import sys @@ -426,8 +426,15 @@ def is_hip() -> bool: return torch.version.hip is not None +@lru_cache(maxsize=None) def is_navi4x() -> bool: - return re.match("gfx12..", os.environ.get("PYTORCH_ROCM_ARCH", "")) is not None + if not torch.cuda.is_available(): + return False + # All (visible) GPUs must be of the same type, + # otherwise FP8 results can't be guaranteed. + archName = torch.cuda.get_device_properties('cuda').gcnArchName + return (archName is not None) and \ + (re.match("gfx12[0-9]{2}", archName) is not None) @lru_cache(maxsize=None) From ae81124673577a6dc998785c296606275ded24cd Mon Sep 17 00:00:00 2001 From: Qiang Li Date: Fri, 25 Oct 2024 14:16:39 -0700 Subject: [PATCH 4/4] [misc] 1. Add platform detection before using torch.cuda; 2. Remove unnecessary detection of Navi4x platform; --- vllm/model_executor/models/llama.py | 4 +--- vllm/utils.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 9ed9eabbd63b5..4a39b686245c6 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -95,9 +95,7 @@ def __init__( self.act_fn = SiluAndMul() def forward(self, x): - # Navi4x is diff from other HIP devices by using e4m3fn fp8 format - if is_hip() and not is_navi4x() \ - and x.shape[0] == 1 and x.shape[1] == 1: + if is_hip() and x.shape[0] == 1 and x.shape[1] == 1: out = torch.empty(x.shape[0], self.gate_up_proj.weight.shape[0] // 2, dtype=x.dtype, diff --git a/vllm/utils.py b/vllm/utils.py index 7c99818eb01d2..af857ca315b38 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -428,7 +428,7 @@ def is_hip() -> bool: @lru_cache(maxsize=None) def is_navi4x() -> bool: - if not torch.cuda.is_available(): + if not is_hip() or not torch.cuda.is_available(): return False # All (visible) GPUs must be of the same type, # otherwise FP8 results can't be guaranteed.
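
As a quick way to see the rule this series encodes end to end -- ROCm devices quantize to float8_e4m3fnuz, except Navi4x (gfx12xx), which shares float8_e4m3fn with CUDA devices -- the following standalone Python sketch mirrors the detection and dtype-selection logic from vllm/utils.py and vllm/_custom_ops.py as they stand after PATCH 4/4. It is an editor's illustration, not part of the patches; it assumes a recent PyTorch build that exposes the float8 dtypes and, on ROCm, the gcnArchName device property, and it queries device index 0 rather than the 'cuda' device string used in the patch.

# Minimal standalone sketch (not part of the patches): mirrors the
# is_hip()/is_navi4x() checks and the FP8 dtype rule introduced by the
# series in vllm/utils.py and vllm/_custom_ops.py.
# Assumptions: a recent PyTorch with the float8 dtypes, and a ROCm build
# exposing `gcnArchName` on device properties; device index 0 is queried
# here instead of the 'cuda' device string used in the patch.
import re

import torch


def is_hip() -> bool:
    # Same check as vllm.utils.is_hip(): HIP builds set torch.version.hip.
    return torch.version.hip is not None


def is_navi4x() -> bool:
    # Mirrors the final (PATCH 4/4) implementation: only meaningful on a HIP
    # build with a visible device, and the reported arch name must be gfx12xx.
    if not is_hip() or not torch.cuda.is_available():
        return False
    arch_name = torch.cuda.get_device_properties(0).gcnArchName
    return arch_name is not None and \
        re.match("gfx12[0-9]{2}", arch_name) is not None


def scaled_fp8_out_dtype() -> torch.dtype:
    # The rule applied throughout the series: ROCm uses float8_e4m3fnuz,
    # except Navi4x, which shares float8_e4m3fn with CUDA devices.
    if is_hip() and not is_navi4x():
        return torch.float8_e4m3fnuz
    return torch.float8_e4m3fn


if __name__ == "__main__":
    print("HIP build:       ", is_hip())
    print("Navi4x (gfx12xx):", is_navi4x())
    print("FP8 quant dtype: ", scaled_fp8_out_dtype())

The split reflects the two FP8 encodings involved: gfx94x (MI300) hardware uses the fnuz variant, while the patches map gfx12xx to the OCP E4M3 format shared with CUDA devices -- which is also why the CMake macro reports CONFLICT when both architecture families appear in one build.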