Skip to content

Commit

Permalink
fixed bug in multi GPU setting
Browse files Browse the repository at this point in the history
  • Loading branch information
Lalit Pradhan committed Mar 6, 2024
1 parent c9a3db5 commit e04e56d
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions vllm/model_executor/models/jais.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,9 @@ def forward(

# Get the logits for the next tokens.
logits = self._get_logits(hidden_states, embedding, embedding_bias)
logits *= torch.tensor(float(output_logits_scale),
dtype=logits.dtype)
if logits is not None:
logits *= torch.tensor(float(output_logits_scale),
dtype=logits.dtype)

# Only perform sampling in the driver worker.
# Note: `_get_logits` is still distributed across TP workers because
Expand Down

0 comments on commit e04e56d

Please sign in to comment.