Commit
[Bugfix] Fix embedding to support 2D inputs (vllm-project#5829)
WoosukKwon authored and jimpang committed Jul 8, 2024
1 parent 51a5f2e commit b089ec6
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/vocab_parallel_embedding.py
@@ -306,11 +306,11 @@ def forward(self, input_):
                 self.shard_indices.added_vocab_end_index)
         else:
             masked_input = input_
-        # Get the embeddings.
+        # Get the embeddings.
         output_parallel = F.embedding(masked_input.long(), self.weight)
         # Mask the output embedding.
         if self.tp_size > 1:
-            output_parallel.masked_fill_(input_mask.unsqueeze(1), 0)
+            output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
         # Reduce across all the model parallel GPUs.
         output = tensor_model_parallel_all_reduce(output_parallel)
         return output
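The fix works because `unsqueeze(-1)` appends the broadcast dimension at the end of the mask, which lines up with the hidden dimension whether the token IDs are 1D `(seq,)` or 2D `(batch, seq)`; the old `unsqueeze(1)` only lined up in the 1D case. A minimal standalone sketch of the masked-lookup step (this `mask_embeddings` helper and the toy weights are illustrative, not vLLM's actual class):

```python
import torch
import torch.nn.functional as F

def mask_embeddings(token_ids: torch.Tensor, weight: torch.Tensor,
                    input_mask: torch.Tensor) -> torch.Tensor:
    """Look up embeddings, then zero out positions flagged by input_mask.

    unsqueeze(-1) gives the mask shape (..., 1), which broadcasts over
    the hidden dimension for both 1D and 2D token-id inputs.
    """
    out = F.embedding(token_ids.long(), weight)
    out.masked_fill_(input_mask.unsqueeze(-1), 0)
    return out

# Toy vocabulary: 4 tokens, hidden size 3 (values are arbitrary).
weight = torch.arange(12, dtype=torch.float32).reshape(4, 3)

# 1D input: shape (seq,) -> output (seq, hidden)
ids_1d = torch.tensor([1, 2])
mask_1d = torch.tensor([False, True])
out_1d = mask_embeddings(ids_1d, weight, mask_1d)

# 2D input: shape (batch, seq) -> output (batch, seq, hidden)
ids_2d = torch.tensor([[1, 2], [0, 3]])
mask_2d = torch.tensor([[False, True], [True, False]])
out_2d = mask_embeddings(ids_2d, weight, mask_2d)
```

With `unsqueeze(1)`, the 2D case would produce a mask of shape `(batch, 1, seq)`, which does not broadcast against the `(batch, seq, hidden)` embedding output.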
