[Model] Fix and clean commandr (vllm-project#3671)
esmeetu authored Mar 28, 2024
1 parent 6d9aa00 commit 10e6322
Showing 1 changed file with 4 additions and 12 deletions.
16 changes: 4 additions & 12 deletions vllm/model_executor/models/commandr.py
@@ -26,7 +26,6 @@
 import torch.utils.checkpoint
 from torch import nn
 from transformers import CohereConfig
-from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS

 from vllm.attention import Attention, AttentionMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -46,8 +45,6 @@
     hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput

-KVCache = Tuple[torch.Tensor, torch.Tensor]
-

 class LayerNorm(nn.Module):

@@ -70,9 +67,6 @@ def forward(self, hidden_states, residuals=None):
         return hidden_states.to(input_dtype), residuals


-ALL_LAYERNORM_LAYERS.append(LayerNorm)
-
-
 # Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
 class CohereMLP(nn.Module):

@@ -137,7 +131,6 @@ def __init__(
         self.max_position_embeddings = config.max_position_embeddings
         self.rope_theta = config.rope_theta
         self.rope_scaling = getattr(config, "rope_scaling", None)
-        self.is_causal = True
         self.qkv_proj = QKVParallelLinear(
             self.hidden_size,
             self.head_dim,
@@ -171,7 +164,7 @@ def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: KVCache,
+        kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
@@ -200,7 +193,7 @@ def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: KVCache,
+        kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -242,7 +235,7 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[KVCache],
+        kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
@@ -269,7 +262,6 @@ def __init__(
     ) -> None:
         super().__init__()
         self.config = config
-        self.unpadded_vocab_size = config.vocab_size
         self.linear_method = linear_method
         self.logits_processor = LogitsProcessor(config.vocab_size,
                                                 scale=config.logit_scale)
@@ -281,7 +273,7 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[KVCache],
+        kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
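For context (not part of the commit), a minimal sketch of what the cleanup amounts to, using illustrative function names: the module-local KVCache tuple alias is removed and the kv_cache / kv_caches parameters are annotated as plain torch.Tensor values, presumably to line up with the annotations used by the other vLLM model implementations at the time; the unused ALL_LAYERNORM_LAYERS registration, is_causal flag, and unpadded_vocab_size attribute are dropped as well.

from typing import Tuple

import torch

# Before: a module-local alias described the KV cache as a tuple of tensors.
KVCache = Tuple[torch.Tensor, torch.Tensor]

def forward_before(hidden_states: torch.Tensor, kv_cache: KVCache) -> torch.Tensor:
    ...  # illustrative stub only

# After: the cache is annotated directly as a single torch.Tensor,
# so the tuple alias is no longer needed.
def forward_after(hidden_states: torch.Tensor, kv_cache: torch.Tensor) -> torch.Tensor:
    ...  # illustrative stub only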
