From e04e56d981bc3fc81f94307c5c274ed065a13bd3 Mon Sep 17 00:00:00 2001
From: Lalit Pradhan
Date: Wed, 6 Mar 2024 19:50:23 +0000
Subject: [PATCH] fixed bug in multi GPU setting

---
 vllm/model_executor/models/jais.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py
index 6d4cefe4f4146..5cbc38ece9fa9 100644
--- a/vllm/model_executor/models/jais.py
+++ b/vllm/model_executor/models/jais.py
@@ -93,8 +93,9 @@ def forward(
         # Get the logits for the next tokens.
         logits = self._get_logits(hidden_states, embedding, embedding_bias)
-        logits *= torch.tensor(float(output_logits_scale),
-                               dtype=logits.dtype)
+        if logits is not None:
+            logits *= torch.tensor(float(output_logits_scale),
+                                   dtype=logits.dtype)
         # Only perform sampling in the driver worker.
         # Note: `_get_logits` is still distributed across TP workers because