diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index c5dd1a63e2f7a..9ff9ba298588a 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -96,13 +96,11 @@ def __init__(
             torch.empty(self.num_total_experts,
                         2 * self.intermediate_size,
                         self.hidden_size,
-                        device="cuda",
                         dtype=self.params_dtype))
         self.w2s = nn.Parameter(
             torch.empty(self.num_total_experts,
                         self.hidden_size,
                         self.intermediate_size,
-                        device="cuda",
                         dtype=self.params_dtype))
 
         set_weight_attrs(self.ws, {
@@ -114,22 +112,20 @@ def __init__(
 
         # Scaling factors for FP8 weights
         self.ws_scale = nn.Parameter(
-            torch.ones(
-                self.num_total_experts, device="cuda", dtype=torch.float32),
+            torch.ones(self.num_total_experts, dtype=torch.float32),
            requires_grad=False) if self.use_fp8 else None
         self.w2s_scale = nn.Parameter(
-            torch.ones(
-                self.num_total_experts, device="cuda", dtype=torch.float32),
+            torch.ones(self.num_total_experts, dtype=torch.float32),
            requires_grad=False) if self.use_fp8 else None
 
         # Scaling factors for FP8 activations
         need_act_scales = (self.use_fp8
                            and quant_config.activation_scheme == "static")
         self.as_scale = nn.Parameter(
-            torch.zeros(1, device="cuda", dtype=torch.float32),
+            torch.zeros(1, dtype=torch.float32),
            requires_grad=False) if need_act_scales else None
         self.a2s_scale = nn.Parameter(
-            torch.zeros(1, device="cuda", dtype=torch.float32),
+            torch.zeros(1, dtype=torch.float32),
            requires_grad=False) if need_act_scales else None
 
         if need_act_scales:
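
Context for the change above (not part of the diff): removing the hard-coded device="cuda" arguments leaves parameter placement to the caller's default device, so the expert weights and FP8 scales can be allocated on CPU (or under a torch.device(...) context) during __init__ and moved to the accelerator later by the weight loader. A minimal sketch of that behavior, using made-up tensor sizes rather than the real Mixtral dimensions:

    # Minimal sketch (assumed sizes, not the real Mixtral config): shows how
    # parameter placement follows the ambient default device once the explicit
    # device="cuda" arguments are gone.
    import torch
    import torch.nn as nn

    num_total_experts = 4      # hypothetical values, for illustration only
    hidden_size = 16
    intermediate_size = 32
    params_dtype = torch.float16

    # Allocate under whatever device the surrounding code selects (CPU here);
    # the runner can later move the parameters to the GPU with .to("cuda").
    with torch.device("cpu"):
        ws = nn.Parameter(
            torch.empty(num_total_experts,
                        2 * intermediate_size,
                        hidden_size,
                        dtype=params_dtype))
        ws_scale = nn.Parameter(
            torch.ones(num_total_experts, dtype=torch.float32),
            requires_grad=False)

    print(ws.device, ws_scale.device)  # cpu cpu -- no CUDA allocation at init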