
Issue with Llama conversion for new release #24

Closed
evellasques opened this issue Apr 4, 2024 · 5 comments

Comments

@evellasques
Contributor

I noticed that in the latest release, llama_module.py was replaced with falcon_module.py, and that test_llama.sh now relies on megatron_gpt_pretraining.py (which uses MegatronGPTModel instead of llama_module.py).

The problem is that MegatronGPTModel in turn relies on transformer.py (instead of llama_module.py), and there, for SwiGLU, the two separate MLP layers (dense_h_to_4h and dense_h_to_4h_2) have been replaced with a single one that is twice as large:

        self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
            hidden_size,
            2*ffn_hidden_size if self.glu_activation_family else ffn_hidden_size,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True,
            resume_from_checkpoint=resume_from_checkpoint,
            use_cpu_initialization=use_cpu_initialization,
            bias=bias,
            sequence_parallel_enabled=sequence_parallel,
            no_async_tensor_model_parallel_allreduce=no_async_tensor_model_parallel_allreduce,
            gradient_accumulation_fusion=gradient_accumulation_fusion,
            transfer_with_static_ring=transfer_with_static_ring,
        )
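
That single projection then just gets split in half at activation time: the first half plays the role of gate_proj and the second of up_proj. A minimal sketch of the idea (not the actual transformer.py code; the use of torch.chunk and the tuple-unpacking of ColumnParallelLinear's output are assumptions here):

    import torch
    import torch.nn.functional as F

    def swiglu_from_merged(x, dense_h_to_4h):
        # One projection of width 2 * ffn_hidden_size, split down the middle:
        # the first half acts as the gate, the second as the up projection.
        merged, _ = dense_h_to_4h(x)  # ColumnParallelLinear typically returns (output, bias)
        gate, up = torch.chunk(merged, 2, dim=-1)
        return F.silu(gate) * up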

While in llama_module.py you had:

        self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
            hidden_size,
            ffn_hidden_size,  # NOTE: When using geglu, divide ffn dim by 2/3 to keep overall params the same.
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True,
            use_cpu_initialization=use_cpu_initialization,
            bias=bias,
            sequence_parallel_enabled=sequence_parallel,
            no_async_tensor_model_parallel_allreduce=no_async_tensor_model_parallel_allreduce,
            gradient_accumulation_fusion=gradient_accumulation_fusion,
            transfer_with_static_ring=transfer_with_static_ring,
        )

        if activation in ['geglu', 'reglu', 'swiglu']:
            # Separate linear layer for *GLU activations.
            # Source: https://github.com/huggingface/transformers/blob/bee361c6f1f7704f8c688895f2f86f6e5ff84727/src/transformers/models/t5/modeling_t5.py#L292
            self.dense_h_to_4h_2 = tensor_parallel.ColumnParallelLinear(
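
As an aside on the NOTE in that snippet: with two GLU projections, scaling ffn_hidden_size by roughly 2/3 keeps the parameter count close to that of a standard 4h MLP. For Llama-7B, for example, (2/3) * 4 * 4096 ≈ 10923, which rounded up to a multiple of 256 gives the familiar ffn_hidden_size of 11008.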

But then the checkpoint conversion script for Llama would have to change as well; it currently contains:

translation = {
        "model.language_model.embedding.word_embeddings.weight": (1, "model.embed_tokens.weight", 0, 0),
        # a['model']['language_model']['word_embeddings']['weight']
        "input_layernorm.weight": (0, "input_layernorm.weight", None, 0),
        "self_attention.query_key_value.weight": (1, "self_attn.query_key_value.weight", 0, 0),
        "self_attention.dense.weight": (1, "self_attn.o_proj.weight", 1, 0),
        "post_attention_layernorm.weight": (0, "post_attention_layernorm.weight", None, 0),
        "self_attention.core_attention.rotary_emb.inv_freq": (0, "self_attn.rotary_emb.inv_freq", None, 0),
        "mlp.dense_h_to_4h.weight": (1, "mlp.gate_proj.weight", 0, 0),
        "mlp.dense_h_to_4h_2.weight": (1, "mlp.up_proj.weight", 0, 0),
        "mlp.dense_4h_to_h.weight": (1, "mlp.down_proj.weight", 1, 0),
        "model.language_model.encoder.final_layernorm.weight": (0, "model.norm.weight", None, 0),
        "model.language_model.output_layer.weight": (1, "lm_head.weight", 0, 0),
    }

This currently causes a crash when I try to load a checkpoint converted from HF Llama, since the model expects dense_h_to_4h to be a concatenation of gate_proj and up_proj (from the HF checkpoint):

RuntimeError: Error(s) in loading state_dict for MegatronGPTModel:
        size mismatch for model.language_model.encoder.layers.0.mlp.dense_h_to_4h.weight: copying a param with shape torch.Size([1376, 4096]) from checkpoint, the shape in current model is torch.Size([2752, 4096]).
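
The numbers line up with that interpretation: the converted checkpoint holds a per-rank gate_proj shard of shape [1376, 4096], while the merged dense_h_to_4h expects gate and up concatenated, i.e. [2 * 1376, 4096] = [2752, 4096]. (A shard size of 1376 would correspond to, e.g., Llama-7B's ffn_hidden_size of 11008 split over 8 tensor-parallel ranks; the exact split depends on the config used here.)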
@aws-kingrj
Contributor

An MLP layer that is twice as large is mathematically equivalent to two separate layers of half the size. If you use MegatronGPTModel as we do, we expect it to work as a Llama module, which is why we haven't updated that llama module.
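
(For what it's worth, this equivalence is easy to check in isolation with plain torch.nn.Linear layers; this is a standalone sketch, not the Neuron code path:)

    import torch

    torch.manual_seed(0)
    hidden, ffn = 16, 44
    x = torch.randn(2, hidden)

    gate_proj = torch.nn.Linear(hidden, ffn, bias=False)
    up_proj = torch.nn.Linear(hidden, ffn, bias=False)

    # One layer twice as wide, whose weight is the concatenation of the two
    # smaller weights along the output dimension.
    merged = torch.nn.Linear(hidden, 2 * ffn, bias=False)
    with torch.no_grad():
        merged.weight.copy_(torch.cat([gate_proj.weight, up_proj.weight], dim=0))

    gate_out, up_out = torch.chunk(merged(x), 2, dim=-1)
    assert torch.allclose(gate_out, gate_proj(x))
    assert torch.allclose(up_out, up_proj(x))

This is also exactly why the conversion script has to concatenate gate_proj and up_proj when the model side uses the merged layer.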

@evellasques
Contributor Author

evellasques commented Apr 5, 2024

> An MLP layer that is twice as large is mathematically equivalent to two separate layers of half the size. If you use MegatronGPTModel as we do, we expect it to work as a Llama module, which is why we haven't updated that llama module.

I understand that. My question is about how to fine-tune models converted with convert_hf_checkpoint_to_nemo_llama.py.

The problem is that the script translates mlp.gate_proj.weight to mlp.dense_h_to_4h.weight and mlp.up_proj.weight to mlp.dense_h_to_4h_2.weight:

        "mlp.dense_h_to_4h.weight": (1, "mlp.gate_proj.weight", 0, 0),
        "mlp.dense_h_to_4h_2.weight": (1, "mlp.up_proj.weight", 0, 0),

Given that these are merged in MegatronGPTModel, I think they should also be merged in convert_hf_checkpoint_to_nemo_llama.py, no?

I made the following changes in convert_hf_checkpoint_to_nemo_llama.py:

    translation = {
        "model.language_model.embedding.word_embeddings.weight": (1, "model.embed_tokens.weight", 0, 0),
        # a['model']['language_model']['word_embeddings']['weight']
        "input_layernorm.weight": (0, "input_layernorm.weight", None, 0),
        "self_attention.query_key_value.weight": (1, "self_attn.query_key_value.weight", 0, 0),
        "self_attention.dense.weight": (1, "self_attn.o_proj.weight", 1, 0),
        "post_attention_layernorm.weight": (0, "post_attention_layernorm.weight", None, 0),
        "self_attention.core_attention.rotary_emb.inv_freq": (0, "self_attn.rotary_emb.inv_freq", None, 0),
        "mlp.dense_h_to_4h.weight": (1, "mlp.gate_proj_up_proj.weight", 0, 0),
        "mlp.dense_4h_to_h.weight": (1, "mlp.down_proj.weight", 1, 0),
        "model.language_model.encoder.final_layernorm.weight": (0, "model.norm.weight", None, 0),
        "model.language_model.output_layer.weight": (1, "lm_head.weight", 0, 0),
    }

    for i in range(config['num_hidden_layers']):
        q = model_llama[f'model.layers.{i}.self_attn.q_proj.weight']
        k = model_llama[f'model.layers.{i}.self_attn.k_proj.weight']
        v = model_llama[f'model.layers.{i}.self_attn.v_proj.weight']
        model_llama[f'model.layers.{i}.self_attn.query_key_value.weight'] = torch.cat([q, k, v], dim=0)

        gate_proj = model_llama[f'model.layers.{i}.mlp.gate_proj.weight']
        up_proj = model_llama[f'model.layers.{i}.mlp.up_proj.weight']
        model_llama[f'model.layers.{i}.mlp.gate_proj_up_proj.weight'] = torch.cat([gate_proj, up_proj], dim=0)

        model_llama.pop(f'model.layers.{i}.self_attn.q_proj.weight')
        model_llama.pop(f'model.layers.{i}.self_attn.k_proj.weight')
        model_llama.pop(f'model.layers.{i}.self_attn.v_proj.weight')
        model_llama.pop(f'model.layers.{i}.mlp.gate_proj.weight')
        model_llama.pop(f'model.layers.{i}.mlp.up_proj.weight')

With that change, llama_test.sh no longer crashes. If you'd like, I can create a PR with this change (and the corresponding one in convert_nemo_checkpoint_to_hf_llama.py).
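
For the NeMo-to-HF direction, the change would presumably be the mirror image: split the merged weight back into two halves before mapping to the HF names. A hypothetical sketch, assuming the concatenation order above (model_nemo and num_layers are placeholders, and in the tensor-parallel case the split would have to happen per rank before the shards are merged):

    for i in range(num_layers):
        key = f'model.language_model.encoder.layers.{i}.mlp.dense_h_to_4h.weight'
        merged = model_nemo[key]
        # First half corresponds to gate_proj, second half to up_proj,
        # mirroring the torch.cat order used in the HF -> NeMo direction.
        gate_proj, up_proj = torch.chunk(merged, 2, dim=0)
        model_nemo[f'model.layers.{i}.mlp.gate_proj.weight'] = gate_proj
        model_nemo[f'model.layers.{i}.mlp.up_proj.weight'] = up_proj
        model_nemo.pop(key)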

@aws-kingrj
Contributor

Yes, you're right. Can you update the PR with that and we can take a look?

@evellasques
Contributor Author

> Yes, you're right. Can you update the PR with that and we can take a look?

Just created PR #26.

@evellasques
Contributor Author

Tracked by PR #26
