diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py
index 5b8311a33c361..e2b4778b94b9e 100644
--- a/tests/kernels/test_cache.py
+++ b/tests/kernels/test_cache.py
@@ -258,19 +258,20 @@ def test_reshape_and_cache_flash(
     del key_caches
     del value_caches
 
+    k_scale = key.amax().item() / 256
+    v_scale = value.amax().item() / 256
+
     # Clone the KV caches.
     if kv_cache_dtype == "fp8":
         cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
-        ops.convert_fp8(cloned_key_cache, key_cache)
+        ops.convert_fp8(cloned_key_cache, key_cache, k_scale, kv_cache_dtype)
         cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
-        ops.convert_fp8(cloned_value_cache, value_cache)
+        ops.convert_fp8(cloned_value_cache, value_cache, v_scale,
+                        kv_cache_dtype)
     else:
         cloned_key_cache = key_cache.clone()
         cloned_value_cache = value_cache.clone()
 
-    # Using default kv_scale
-    k_scale = v_scale = 1.0
-
     # Call the reshape_and_cache kernel.
     opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
             (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
@@ -281,9 +282,15 @@ def test_reshape_and_cache_flash(
 
     if kv_cache_dtype == "fp8":
         result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
-        ops.convert_fp8(result_key_cache, key_cache)
+        ops.convert_fp8(result_key_cache,
+                        key_cache,
+                        k_scale,
+                        kv_dtype=kv_cache_dtype)
         result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
-        ops.convert_fp8(result_value_cache, value_cache)
+        ops.convert_fp8(result_value_cache,
+                        value_cache,
+                        v_scale,
+                        kv_dtype=kv_cache_dtype)
 
     # Run the reference implementation.
     block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor")
diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index 234c87d5c4edb..658805d35be0a 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -759,8 +759,6 @@ def forward(
         v_scale: float = 1.0,
         attn_type: AttentionType = AttentionType.DECODER,
     ) -> torch.Tensor:
-        assert k_scale == 1.0 and v_scale == 1.0, (
-            "key/v_scale is not supported in FlashInfer.")
         if attn_type != AttentionType.DECODER:
             raise NotImplementedError("Encoder self-attention and "
                                       "encoder/decoder cross-attention "
@@ -874,7 +872,12 @@ def unified_flash_infer(
             assert prefill_meta is not None
             assert prefill_meta.prefill_wrapper is not None
             prefill_output = prefill_meta.prefill_wrapper.forward(
-                query, kv_cache, logits_soft_cap=logits_soft_cap, causal=True)
+                query,
+                kv_cache,
+                logits_soft_cap=logits_soft_cap,
+                causal=True,
+                k_scale=k_scale,
+                v_scale=v_scale)
     if decode_meta := attn_metadata.decode_metadata:
         assert attn_metadata.decode_metadata is not None
         assert attn_metadata.decode_metadata.decode_wrapper is not None
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index dc5f47eb9b0fb..9694f2b8208e2 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -141,8 +141,11 @@ def create_weights(
         layer.register_parameter("input_scale", scale)
 
     def process_weights_after_loading(self, layer: Module) -> None:
-        max_w_scale, weight = requantize_with_max_scale(
-            layer.weight, layer.weight_scale, layer.logical_widths)
+        weight = layer.weight
+        max_w_scale = layer.weight_scale.max()
+        if not (layer.weight_scale == layer.weight_scale[0]).all():
+            max_w_scale, weight = requantize_with_max_scale(
+                layer.weight, layer.weight_scale, layer.logical_widths)
         layer.weight = Parameter(weight.t(), requires_grad=False)
         layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
         layer.input_scale = Parameter(layer.input_scale.max(),
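
For reference, a minimal round-trip sketch of the scaled ops.convert_fp8 path the updated test exercises. It assumes a CUDA build of vLLM that exposes the fp8 cache conversion kernel; the tensor shape, the abs-amax/256 scale choice, and the tolerances below are illustrative assumptions, not values taken from the test.

import torch

from vllm import _custom_ops as ops

# Stand-in fp16 tensor playing the role of a KV-cache block (shape is arbitrary).
x = torch.randn(16, 8, 64, dtype=torch.float16, device="cuda")
# Per-tensor scale in the spirit of the test's key.amax().item() / 256
# (abs() added here to keep the quantized range symmetric).
scale = x.abs().amax().item() / 256

# Quantize into the fp8 cache representation (stored as uint8)...
x_fp8 = torch.empty_like(x, dtype=torch.uint8)
ops.convert_fp8(x_fp8, x, scale, kv_dtype="fp8")

# ...then dequantize back with the same explicit scale.
x_roundtrip = torch.empty_like(x)
ops.convert_fp8(x_roundtrip, x_fp8, scale, kv_dtype="fp8")

# fp8 e4m3 round-tripping is lossy, so the comparison uses a coarse tolerance.
torch.testing.assert_close(x, x_roundtrip, atol=0.001, rtol=0.1)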