[Model] Cohere CommandR+ (vllm-project#3829)
saurabhdash2512 authored and joerunde committed Apr 11, 2024
1 parent 91b83ad commit 1842394
Showing 1 changed file with 40 additions and 8 deletions.
48 changes: 40 additions & 8 deletions vllm/model_executor/models/commandr.py
@@ -25,6 +25,7 @@
 import torch
 import torch.utils.checkpoint
 from torch import nn
+from torch.nn.parameter import Parameter
 from transformers import CohereConfig
 
 from vllm.attention import Attention, AttentionMetadata
@@ -39,20 +40,21 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.parallel_utils.parallel_state import (
-    get_tensor_model_parallel_world_size)
+    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.model_executor.utils import set_weight_attrs
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
 
 
 class LayerNorm(nn.Module):
 
-    def __init__(self, hidden_size, eps=1e-5, bias=False):
+    def __init__(self, param_shape=None, eps=1e-5):
         super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.bias = nn.Parameter(torch.zeros(hidden_size)) if bias else None
+        self.weight = nn.Parameter(torch.ones(param_shape))
         self.variance_epsilon = eps
+        set_weight_attrs(self.weight, {"weight_loader": self.weight_loader})
 
     def forward(self, hidden_states, residuals=None):
         input_dtype = hidden_states.dtype
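The reworked LayerNorm above accepts an arbitrary param_shape instead of a single hidden_size and drops the bias term, so the same class can carry a per-head (num_heads, head_dim) weight for QK normalization. A minimal standalone sketch of what the forward (continued in the next hunk) computes, assuming the standard last-dimension mean/variance from the collapsed lines and using made-up sizes:

    import torch

    num_heads, head_dim = 8, 64                     # hypothetical sizes
    weight = torch.ones(num_heads, head_dim)        # mirrors param_shape=(num_heads, head_dim)
    x = torch.randn(2, 16, num_heads, head_dim)     # (batch, seq, heads, head_dim)

    xf = x.to(torch.float32)
    mean = xf.mean(-1, keepdim=True)
    variance = (xf - mean).pow(2).mean(-1, keepdim=True)
    y = (xf - mean) * torch.rsqrt(variance + 1e-5)
    y = (weight.to(torch.float32) * y).to(x.dtype)  # per-head scale, no bias any more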
@@ -62,10 +64,20 @@ def forward(self, hidden_states, residuals=None):
         hidden_states = (hidden_states -
                          mean) * torch.rsqrt(variance + self.variance_epsilon)
         hidden_states = self.weight.to(torch.float32) * hidden_states
-        if self.bias is not None:
-            hidden_states = hidden_states + self.bias.to(torch.float32)
         return hidden_states.to(input_dtype), residuals
 
+    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
+        tp_rank = get_tensor_model_parallel_rank()
+        shard_dim = 0 if param.dim() != 1 else None
+        param_data = param.data
+        if shard_dim is not None:
+            shard_size = param_data.shape[shard_dim]
+            start_idx = tp_rank * shard_size
+            loaded_weight = loaded_weight.narrow(shard_dim, start_idx,
+                                                 shard_size)
+        assert param_data.shape == loaded_weight.shape
+        param_data.copy_(loaded_weight)
+
 
 # Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
 class CohereMLP(nn.Module):
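The new weight_loader splits a multi-dimensional norm weight across tensor-parallel ranks along dim 0 (the head dimension) and copies 1-D weights whole. A rough standalone illustration of the narrow-based sharding, with hypothetical sizes and a loop over ranks standing in for get_tensor_model_parallel_rank():

    import torch

    total_heads, head_dim, tp_size = 32, 64, 4        # hypothetical sizes
    full_weight = torch.randn(total_heads, head_dim)  # e.g. a q_norm weight in the checkpoint
    shard_size = total_heads // tp_size               # heads kept per rank

    for tp_rank in range(tp_size):
        # Same slicing as weight_loader for a 2-D param: one contiguous block of heads per rank.
        shard = full_weight.narrow(0, tp_rank * shard_size, shard_size)
        assert shard.shape == (shard_size, head_dim)

    # A 1-D weight (e.g. the final model norm) has dim() == 1, so shard_dim is None
    # and every rank copies the full tensor.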
@@ -131,6 +143,7 @@ def __init__(
         self.max_position_embeddings = config.max_position_embeddings
         self.rope_theta = config.rope_theta
         self.rope_scaling = getattr(config, "rope_scaling", None)
+        self.use_qk_norm = getattr(config, "use_qk_norm", False)
         self.qkv_proj = QKVParallelLinear(
             self.hidden_size,
             self.head_dim,
@@ -159,6 +172,22 @@ def __init__(
             self.scaling,
             num_kv_heads=self.num_kv_heads,
         )
+        if self.use_qk_norm:
+            self.q_norm = LayerNorm(param_shape=(self.num_heads,
+                                                 self.head_dim),
+                                    eps=config.layer_norm_eps)
+            self.k_norm = LayerNorm(param_shape=(self.num_kv_heads,
+                                                 self.head_dim),
+                                    eps=config.layer_norm_eps)
+
+    def _apply_qk_norm(self, q, k):
+        q = q.view(*q.shape[:-1], -1, self.head_dim)
+        k = k.view(*k.shape[:-1], -1, self.head_dim)
+        q, _ = self.q_norm(q)
+        k, _ = self.k_norm(k)
+        q = q.view(*q.shape[:-2], -1)
+        k = k.view(*k.shape[:-2], -1)
+        return q, k
 
     def forward(
         self,
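_apply_qk_norm operates on the flattened q/k projections: it unflattens the last dimension into (heads, head_dim), lets q_norm/k_norm normalize each head over head_dim, then flattens back. A small shape walk-through with hypothetical sizes (the norm call itself is elided here):

    import torch

    num_tokens, num_heads, head_dim = 5, 8, 64          # hypothetical sizes
    q = torch.randn(num_tokens, num_heads * head_dim)   # q slice produced by qkv_proj

    q_heads = q.view(*q.shape[:-1], -1, head_dim)       # -> (num_tokens, num_heads, head_dim)
    assert q_heads.shape == (num_tokens, num_heads, head_dim)
    # ... self.q_norm(q_heads) would normalize each head over head_dim here ...
    q_flat = q_heads.view(*q_heads.shape[:-2], -1)      # back to (num_tokens, num_heads * head_dim)
    assert q_flat.shape == q.shape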
@@ -169,6 +198,8 @@ def forward(
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        if self.use_qk_norm:
+            q, k = self._apply_qk_norm(q, k)
         q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
         output, _ = self.o_proj(attn_output)
@@ -186,7 +217,7 @@ def __init__(self,
         self.self_attn = CohereAttention(config, linear_method=linear_method)
 
         self.mlp = CohereMLP(config, linear_method=linear_method)
-        self.input_layernorm = LayerNorm(config.hidden_size,
+        self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
                                          eps=config.layer_norm_eps)
 
     def forward(
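Note that param_shape=(config.hidden_size) is a parenthesized int, not a one-element tuple, so this layer norm (and the model-level norm in the next hunk) keeps a 1-D weight; by the dim() check in weight_loader above, such weights are replicated rather than sharded. A quick sanity check with a made-up size:

    import torch

    hidden_size = 8192                      # hypothetical
    assert (hidden_size) == hidden_size     # parentheses alone do not make a tuple
    w = torch.ones((hidden_size))           # 1-D weight
    assert w.dim() == 1                     # -> shard_dim is None in weight_loader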
@@ -229,7 +260,8 @@ def __init__(
             CohereDecoderLayer(config, linear_method=linear_method)
             for _ in range(config.num_hidden_layers)
         ])
-        self.norm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.norm = LayerNorm(param_shape=(config.hidden_size),
+                              eps=config.layer_norm_eps)
 
     def forward(
         self,
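With the QK-norm path in place, a CommandR+ checkpoint that sets use_qk_norm should load through vLLM's usual offline API. A hedged usage sketch (the model name and tensor_parallel_size below are illustrative, not part of this commit):

    from vllm import LLM, SamplingParams

    # Tensor parallelism exercises the new weight_loader: per-head norm weights
    # are sharded across GPUs, 1-D norm weights are replicated.
    llm = LLM(model="CohereForAI/c4ai-command-r-plus", tensor_parallel_size=4)
    outputs = llm.generate(["Write a haiku about attention."],
                           SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)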
