enable yuan autotp & add conv tp (microsoft#5428)
This PR enables AutoTP for the Yuan model and adds conv TP.

The Yuan model uses shared QK.
For example:
q_linear_out = [q1, q2, q3, q4, q5, ... , q16]
k_linear_out = [k1, k2, k3, k4, k5, ... , k16]

after share qk:
TP=1:
q' = [q1,q2,q3,q4,  q9,q10,q11,q12,  k1,k2,k3,k4,  k9,k10,k11,k12]
k' = [q5,q6,q7,q8,  q13,q14,q15,q16,  k5,k6,k7,k8,  k13,k14,k15,k16]
v' = [v1,v2,v3,v4,  v5,v6,v7,v8, v9,v10,v11,v12, v13,v14,v15,v16]

TP=2:
rank0:
q'_0 = [q1,q2,q3,q4, k1,k2,k3,k4]
k'_0 = [q5,q6,q7,q8, k5,k6,k7,k8]
v'_0 = [v1,v2,v3,v4, v5,v6,v7,v8] -> v'_0 is wrong! The expected value is:
[v1,v2,v3,v4, v9,v10,v11,v12]
rank1:
q'_1 = [q9,q10,q11,q12, k9,k10,k11,k12]
k'_1 = [q13,q14,q15,q16, k13,k14,k15,k16]
v'_1 = [v9,v10,v11,v12, v13,v14,v15,v16] -> v'_1 is wrong! The expected value
is: [v5,v6,v7,v8, v13,v14,v15,v16]

To avoid modifying the modeling code, we adjust the value and o_proj weights
to fit this QK layout (the sketch below reproduces the remapping for the example above).
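
A quick way to see the remapping is a standalone sketch (not part of this diff) that reproduces the expected v' shards for the 16-head, TP=2 example above; the loop mirrors the head mapping used by the new shard_value_with_share_qk helper:

```python
# Minimal sketch: which value heads each rank should own under shared QK.
# Assumes 16 heads and TP=2, matching the example above.
num_heads, world_size = 16, 2
head_per_rank = num_heads // world_size

for rank in range(world_size):
    q_head_start = rank * head_per_rank
    v_head_ids = []
    i = 0
    # every pair of neighboring q heads maps to one v head
    while i < head_per_rank:
        v_head_ids.append(q_head_start // 2)
        q_head_start += 2
        i += 2
    # the neighboring k heads map to the matching v heads in the second half
    v_head_ids.extend([h + num_heads // 2 for h in v_head_ids])
    print(rank, [f"v{h + 1}" for h in v_head_ids])  # 1-based, to match v1..v16 above

# rank 0 -> ['v1', 'v2', 'v3', 'v4', 'v9', 'v10', 'v11', 'v12']
# rank 1 -> ['v5', 'v6', 'v7', 'v8', 'v13', 'v14', 'v15', 'v16']
```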

We also added conv TP to support models that include heavy conv computation.
It is similar to the linear TP policy (a toy equivalence sketch follows this list).
if not last_conv_layer:

- 1. Divide the conv weight across the ranks along the output channel dimension.
- 2. Apply conv2d.

else:

- 1. Divide the conv weight across the ranks along the input channel dimension.
- 2. Apply conv2d.
- 3. Use allreduce to sum the partial outputs.
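
The following toy sketch (assuming two stacked bias-free convs and TP=2, using plain torch.nn.functional rather than the DeepSpeed classes added below) shows why this policy reproduces the unsharded result:

```python
import torch
import torch.nn.functional as F

# conv1 (not the last conv layer): shard the weight along the output channel.
# conv2 (the last conv layer): shard the weight along the input channel, then all-reduce.
w1 = torch.randn(8, 3, 3, 3)    # [out_channels, in_channels, kH, kW]
w2 = torch.randn(4, 8, 3, 3)
x = torch.randn(1, 3, 16, 16)

reference = F.conv2d(F.conv2d(x, w1, padding=1), w2, padding=1)

partial_outputs = []
for rank in range(2):
    w1_shard = w1.chunk(2, dim=0)[rank]   # output-channel shard for conv1
    w2_shard = w2.chunk(2, dim=1)[rank]   # input-channel shard for conv2
    partial_outputs.append(F.conv2d(F.conv2d(x, w1_shard, padding=1), w2_shard, padding=1))

# The all-reduce sums the per-rank partial outputs into the full result.
# Loose tolerance to allow for float32 accumulation-order differences.
assert torch.allclose(reference, sum(partial_outputs), atol=1e-3)
```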

---------

Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Lev Kurilenko <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
4 people authored Jun 18, 2024
1 parent 3bdd187 commit 8ea995e
Showing 4 changed files with 163 additions and 7 deletions.
26 changes: 19 additions & 7 deletions deepspeed/module_inject/auto_tp.py
@@ -13,7 +13,7 @@
from deepspeed import comm as dist
from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce
from deepspeed.accelerator import get_accelerator
from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw, shard_chunk_mlp
from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw, shard_value_with_share_qk, shard_chunk_mlp
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list


@@ -134,7 +134,7 @@ def is_load_module(module):
        load_layer_names = [
            "LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm", "FalconLinear",
            "MistralRMSNorm", "T5LayerNorm", "MixtralRMSNorm", "Phi3RotaryEmbedding", "Phi3SuScaledRotaryEmbedding",
            "Phi3RMSNorm"
            "Phi3RMSNorm", "YuanRMSNorm", "YuanRotaryEmbedding"
        ]
        return module.__class__ in load_layers or module._get_name() in load_layer_names

@@ -331,6 +331,16 @@ def _replace(self, child, name, conv_linear_layer):
        # For mixtral-7x8b, need to skip MoE gate linear replace.
        if name == "block_sparse_moe.gate":
            return child
        # For Yuan model
        if 'Yuan' in str(self.module):
            if 'v_proj' in name:
                weight, bias = shard_value_with_share_qk(child.weight.data, child.bias, dist.get_rank(),
                                                         dist.get_world_size(), True)
                return LinearLayer(weight=weight, bias=bias)
            elif 'o_proj' in name:
                weight, bias = shard_value_with_share_qk(child.weight.data, child.bias, dist.get_rank(),
                                                         dist.get_world_size(), False)
                return LinearAllreduce(weight, bias, self.mp_group)
        # for phi3.
        if 'gate_up_proj' in name:
            weight, bias = shard_chunk_mlp(child.weight.data, child.bias, dist.get_rank(), dist.get_world_size())
@@ -412,11 +422,13 @@ def _slice_embedding(self, child, name, conv_linear_layer):
    def update_mp_params(self, child):
        if getattr(child, "replaced", False) == True:
            return
        for param in [
                "n_heads", "inner_dim", "num_heads", "num_kv", "num_attention_heads", "num_attn_heads",
                "all_head_size", "embed_dim", "hidden_size", "num_key_value_heads", "num_kv_heads", "kv_n_heads",
                "d_model"
        ]:
        param_list = [
            "n_heads", "inner_dim", "num_heads", "num_kv", "num_attention_heads", "num_attn_heads", "all_head_size",
            "embed_dim", "hidden_size", "num_key_value_heads", "num_kv_heads", "kv_n_heads", "d_model"
        ]
        for param in param_list:
            if "Yuan" in str(child) and 'embed_dim' in param_list:
                param_list.remove('embed_dim')
            if hasattr(child, param):
                param_val = getattr(child, param)
                setattr(child, param, get_shard_size(param_val, self.mp_size))
55 changes: 55 additions & 0 deletions deepspeed/module_inject/fusedqkv_utils.py
@@ -142,6 +142,61 @@ def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None, module=None):
    return _bloom_type_transpose(src, mp_size)


# For share qk type:
# q = [q1,...,q_{n/4}, q_{n/2+1},...,q_{3n/4}, k1,...,k_{n/4}, k_{n/2+1},...,k_{3n/4}]
# k = [q_{n/4+1},...,q_{n/2}, q_{3n/4+1},...,q_n, k_{n/4+1},...,k_{n/2}, k_{3n/4+1},...,k_n]
# To avoid modifying the modeling code, we adjust the value and o_proj weights to fit this qk type.
def shard_value_with_share_qk(
        weight,
        bias,
        rank,
        world_size,
        shard_value=True  # True -> shard_value; False -> shard_oproj
):
    if shard_value:
        total_size = weight.shape[0]
        weight_cat_dim = 0
    else:
        total_size = weight.shape[1]
        weight_cat_dim = 1
    num_heads = get_num_kv_heads()
    head_dim = total_size // num_heads
    assert (num_heads % world_size == 0)
    if world_size > num_heads // 2:
        raise RuntimeError(f"world_size {world_size} is larger than half of num_heads {num_heads}")
    head_per_rank = num_heads // world_size
    q_head_start = rank * head_per_rank
    # mapping q_head to v_head
    v_head_ids = []
    i = 0
    # mapping neighbor q_head to v_head
    while i < head_per_rank:
        v_head_ids.append(q_head_start // 2)
        q_head_start += 2
        i = i + 2

    # mapping neighbor k_head to v_head
    v_head_ids.extend([i + num_heads // 2 for i in v_head_ids])
    sharded_weight = []
    sharded_bias = []
    for head_id in v_head_ids:
        if shard_value:
            sharded_weight.append(weight[head_id * head_dim:(head_id + 1) * head_dim])
            if bias is not None:
                sharded_bias.append(bias.data[head_id * head_dim:(head_id + 1) * head_dim])
        else:
            sharded_weight.append(weight[:, head_id * head_dim:(head_id + 1) * head_dim])
    sharded_weight = torch.cat(sharded_weight, dim=weight_cat_dim)
    if bias is not None:
        if shard_value:
            sharded_bias = torch.cat(sharded_bias, dim=0)
        else:
            sharded_bias = bias / float(world_size)
        return torch.nn.Parameter(sharded_weight), torch.nn.Parameter(sharded_bias)
    else:
        return torch.nn.Parameter(sharded_weight), None


# For phi3 with chunk mlp, adjust the weight order.
def shard_chunk_mlp(
weight,
62 changes: 62 additions & 0 deletions deepspeed/module_inject/layers.py
@@ -13,6 +13,68 @@
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list


class TensorParallelConv2d(nn.Module):

    def __init__(self, conv, rank, world_size, shard_by_oc):
        super().__init__()
        self.rank = rank
        self.world_size = world_size
        self.shard_by_oc = shard_by_oc
        self.shard_weights(conv)

    # Split along the input/output channel depending on whether it is the last conv layer.
    def shard_weights(self, conv):
        if self.shard_by_oc:
            total_size = conv.weight.shape[0]
        else:
            total_size = conv.weight.shape[1]
        bias_data = None
        cols_per_rank = [0]
        for i in range(self.world_size - 1, -1, -1):
            cols = total_size // self.world_size
            if i < total_size % self.world_size:
                cols += 1
            cols_per_rank.append(cols_per_rank[-1] + cols)
        weight_data = conv.weight.data
        if self.shard_by_oc:
            # not last conv layer, split output channel
            weight_data = weight_data[cols_per_rank[self.rank]:cols_per_rank[self.rank + 1]]
            if conv.bias is not None:
                bias_data = conv.bias.data[cols_per_rank[self.rank]:cols_per_rank[self.rank + 1]]
        else:
            # last conv layer, split input channel
            weight_data = weight_data[:, cols_per_rank[self.rank]:cols_per_rank[self.rank + 1]]
            if conv.bias is not None:
                bias_data = conv.bias.data / float(self.world_size)
        self.conv = nn.Conv2d(weight_data.shape[1], weight_data.shape[0], conv.kernel_size, conv.stride, conv.padding,
                              conv.dilation, conv.groups, conv.bias is not None, conv.padding_mode)
        self.conv.weight = torch.nn.Parameter(weight_data)
        if conv.bias is not None:
            self.conv.bias = torch.nn.Parameter(bias_data)
        del conv

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return self.conv(input)


class TensorParallelOcShardConv2d(TensorParallelConv2d):

    def __init__(self, conv, rank, world_size):
        super().__init__(conv, rank, world_size, True)


class TensorParallelIcShardConv2d(TensorParallelConv2d):

    def __init__(self, conv, rank, world_size):
        super().__init__(conv, rank, world_size, False)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        out = self.conv(input)
        if self.world_size > 1:
            dist.inference_all_reduce(out)
        return out


class LinearAllreduce(nn.Module):

    def __init__(self, weight, bias=None, mp_group=None):
27 changes: 27 additions & 0 deletions deepspeed/module_inject/replace_module.py
@@ -14,6 +14,7 @@
from deepspeed.accelerator import get_accelerator
from .replace_policy import replace_policies, generic_policies
from .auto_tp import AutoTP, ReplaceWithTensorSlicing, Loading
from .layers import TensorParallelOcShardConv2d, TensorParallelIcShardConv2d

from deepspeed import comm as dist
from deepspeed.module_inject.tp_shard import set_num_kv_heads, set_n_embd, set_num_attention_heads
@@ -340,6 +341,28 @@ def set_lm_head(module):
            module = replace_wo_policy(module, ("embed_out", ), 0, "embed_out")
        return module

    def conv2d_parallel_shard_weights(model, rank, world_size):
        # add conv policy
        shard_oc_name = ["conv1"]
        shard_ic_name = ["conv2"]
        for name, sub_m in model.named_children():
            for l_name, l_sub_m in sub_m.named_children():
                if l_name in shard_oc_name:
                    TPConv2d = TensorParallelOcShardConv2d(
                        l_sub_m,
                        rank,
                        world_size,
                    )
                    setattr(sub_m, l_name, TPConv2d)
                if l_name in shard_ic_name:
                    TPConv2d = TensorParallelIcShardConv2d(
                        l_sub_m,
                        rank,
                        world_size,
                    )
                    setattr(sub_m, l_name, TPConv2d)
            conv2d_parallel_shard_weights(sub_m, rank, world_size)

    if checkpoint_dict is not None and not config.replace_with_kernel_inject:
        # AutoTP shard loading
        checkpoint = checkpoint_dict["checkpoints"]
@@ -354,6 +377,10 @@ def set_lm_head(module):
            pbar.update(1)
            gc.collect()
        replaced_module = set_lm_head(replaced_module)
        # conv2d tp module replace
        # Currently this is for the Yuan model. Extend the model list and conv policy to decide whether to replace conv layers.
        if 'Yuan' in str(replaced_module):
            conv2d_parallel_shard_weights(replaced_module, dist.get_rank(), dist.get_world_size())
    else:
        replaced_module = replace_module(model=model,
                                         orig_class=orig_layer_impl,
