Skip to content

Commit

Permalink
Remove unsqueeze
Browse files Browse the repository at this point in the history
  • Loading branch information
dolszewska committed Sep 9, 2024
1 parent 0440fb2 commit 9916b6b
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions vllm/worker/habana_model_runner.py
Original file line number Diff line number Diff line change
@@ -899,7 +899,7 @@ def _prepare_decode(
for seq_id in seq_ids:
seq_data = seq_group_metadata.seq_data[seq_id]
generation_token = seq_data.get_last_token_id()
- input_tokens.append(generation_token)
+ input_tokens.append([generation_token])

seq_len = seq_data.get_len()
position = seq_len - 1
@@ -928,7 +928,7 @@ def _prepare_decode(
lora_logits_mask = lora_mask
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
- device=self.device).unsqueeze(-1)
+ device=self.device)
input_positions = torch.tensor(input_positions,
dtype=torch.long,
device=self.device)
Expand Down

0 comments on commit 9916b6b

Please sign in to comment.