diff --git a/src/tensor_parallel/slicing_configs.py b/src/tensor_parallel/slicing_configs.py index b1a3f5f..a5b0918 100644 --- a/src/tensor_parallel/slicing_configs.py +++ b/src/tensor_parallel/slicing_configs.py @@ -406,6 +406,63 @@ def get_llama_config(model_config: PretrainedConfig, devices: Sequence[torch.dev return config +def get_mixtral_config(model_config: PretrainedConfig, devices: Sequence[torch.device]) -> Config: + assert model_config.model_type == "mixtral", f"Trying to pass {model_config.model_type} as mixtral config" + + world_size = len(devices) + head_dim = model_config.hidden_size // model_config.num_attention_heads + num_kv = model_config.num_key_value_heads + q_per_kv = model_config.num_attention_heads // model_config.num_key_value_heads + new_modeling = True + + gather_kv_across_ranks = CollectiveOperation( + world_size=world_size, func=lambda *kvs: gather_kv(*kvs, world_size=world_size) + ) # this operation ensures that we get attention cache for all heads on each device + + config = Config( + state_rules={ + # MixtralAttention + r".*self_attn\.q_proj\.weight$": SplitInChunks( + world_size=world_size, dim=0, chunk_size=q_per_kv * head_dim + ), + r".*self_attn\.k_proj\.weight$": SplitInChunks(world_size=world_size, dim=0, chunk_size=head_dim), + r".*self_attn\.v_proj\.weight$": SplitInChunks(world_size=world_size, dim=0, chunk_size=head_dim), + r".*self_attn\.o_proj\.weight$": SplitInChunks( + world_size=world_size, dim=1, chunk_size=q_per_kv * head_dim + ), + # MixtralFeedForward + r".*experts\.\d+\.w1\.weight$": Split(world_size=world_size, dim=0), + r".*experts\.\d+\.w2\.weight$": Split(world_size=world_size, dim=1), + r".*experts\.\d+\.w3\.weight$": Split(world_size=world_size, dim=0), + # MixtralModel + r".*embed_tokens.weight$": Split(world_size=world_size, dim=1), + r".*lm_head\.weight$": Split(world_size=world_size, dim=0), + }, + input_rules={ + r".*self_attn$": {"past_key_value": select_kv_for_rank}, + }, + output_rules={ + r".*self_attn$": {0: "sum", 2: gather_kv_across_ranks}, + r".*experts\.\d+$": {0: "sum"}, + r".*embed_tokens$": {0: "gather -1"}, + r".*lm_head$": {0: "gather -1"}, + }, + attr_rules={ + r".*self_attn$": { + "hidden_size": partial(split_inner_dim, num_heads=num_kv, world_size=world_size), + "num_heads": lambda n, rank: q_per_kv + * split_num_heads(n // q_per_kv, rank=rank, world_size=world_size), + } + }, + ) + + config.attr_rules[re.compile(".*self_attn$")]["num_key_value_heads"] = partial( + split_num_heads, world_size=world_size + ) + + return config + + def get_refined_web_config(model_config: PretrainedConfig, devices: Sequence[torch.device]) -> Config: # We can't use `RWConfig`` since it's custom code assert model_config.model_type == "RefinedWeb", f"Trying to pass {model_config.model_type} as RefinedWeb config" @@ -470,5 +527,6 @@ def get_refined_web_config(model_config: PretrainedConfig, devices: Sequence[tor "gpt_neox": get_gpt_neox_config, "codegen": get_codegen_config, "llama": get_llama_config, + "mixtral": get_mixtral_config, "RefinedWeb": get_refined_web_config, }