[BugFix] tensor.get_device() -> tensor.device (vllm-project#3604)
jikunshang authored Mar 25, 2024
commit bb1af9c (1 parent: ac8029a)
Showing 1 changed file with 2 additions and 2 deletions.
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -108,7 +108,7 @@ def _forward(
         query_pass = query[..., self.rotary_dim:]
         key_pass = key[..., self.rotary_dim:]

-        self.cos_sin_cache = self.cos_sin_cache.to(positions.get_device())
+        self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
         cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
                                      if offsets is not None else positions]
         cos, sin = cos_sin.chunk(2, dim=-1)
@@ -142,7 +142,7 @@ def forward(
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        self.cos_sin_cache = self.cos_sin_cache.to(positions.get_device())
+        self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
         # ops.rotary_embedding()/batched_rotary_embedding()
         # are in-place operations that update the query and key tensors.
         if offsets is not None:
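
For context: Tensor.get_device() returns only a bare integer device ordinal (-1 for CPU tensors on recent PyTorch builds; older builds raise), whereas Tensor.device is a full torch.device carrying both the device type and the index, so passing it to .to() also behaves correctly on CPU and non-CUDA backends. Below is a minimal sketch of the difference, assuming a recent PyTorch; the cos_sin_cache/positions names mirror the patched code but are otherwise illustrative.

    import torch

    positions = torch.arange(8)        # position ids on the CPU
    cos_sin_cache = torch.randn(16, 4)

    print(positions.device)            # device(type='cpu'): full type + index
    print(positions.get_device())      # -1: bare integer ordinal for CPU tensors

    # Portable: torch.device carries the device type, so this is a no-op on
    # CPU and moves the cache to the matching accelerator otherwise.
    cos_sin_cache = cos_sin_cache.to(positions.device)

    # Not portable: .to() interprets a bare int as an accelerator ordinal,
    # so the -1 returned for CPU tensors raises a RuntimeError here.
    # cos_sin_cache = cos_sin_cache.to(positions.get_device())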
