diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 6d4cefe4f4146..5cbc38ece9fa9 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -93,8 +93,9 @@ def forward( # Get the logits for the next tokens. logits = self._get_logits(hidden_states, embedding, embedding_bias) - logits *= torch.tensor(float(output_logits_scale), - dtype=logits.dtype) + if logits is not None: + logits *= torch.tensor(float(output_logits_scale), + dtype=logits.dtype) # Only perform sampling in the driver worker. # Note: `_get_logits` is still distributed across TP workers because