diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 93ec5c12536fb..87068644112c0 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -245,13 +245,11 @@ def _yarn_find_correction_range(low_rot: int, def _yarn_linear_ramp_mask(low: float, high: float, dim: int, - dtype: torch.dtype, - device: torch.device) -> torch.Tensor: + dtype: torch.dtype) -> torch.Tensor: if low == high: high += 0.001 # Prevent singularity - linear_func = (torch.arange(dim, dtype=dtype, device=device) - - low) / (high - low) + linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low) ramp_func = torch.clamp(linear_func, 0, 1) return ramp_func