Skip to content

Commit

Permalink
fix typos in convert_mixtral_nemo_to_hf.py and convert_starcoder2_nemo_to_hf.py (NVIDIA#9325)
Browse files Browse the repository at this point in the history

Signed-off-by: evellasques <[email protected]>
  • Loading branch information
evellasques authored May 29, 2024
1 parent 8a8c453 commit 136aeee
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
3 changes: 2 additions & 1 deletion scripts/checkpoint_converters/convert_mixtral_nemo_to_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def load_config(hf_model_name, nemo_config):
hf_config.num_key_value_heads = nemo_config.num_query_groups
hf_config.num_local_experts = nemo_config.num_moe_experts
assert hf_config.num_local_experts > 0, "num_experts must be greater than zero."
hf_config.num_experts_per_tok = nemo_config.num_experts_per_token
hf_config.num_experts_per_tok = nemo_config.moe_router_topk
assert hf_config.num_experts_per_tok > 0, "num_experts_per_token must be greater than zero."
if nemo_config.activation == 'fast-swiglu':
hf_config.activation = 'silu'
Expand Down Expand Up @@ -122,6 +122,7 @@ def convert(in_file, precision=None) -> None:
embed_weights_base_name = f'model.language_model.embedding.word_embeddings.weight'
state_dict[hf_embed_weight_name] = param_to_weights(ckpt[embed_weights_base_name])

head_num = model.cfg.num_attention_heads
if nemo_config.num_query_groups is None or nemo_config.num_query_groups == head_num:
num_query_groups = head_num
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def convert(in_file, precision=None, cpu_only=True) -> None:
config = load_config(args.hf_model_name, nemo_config)
model = AutoModelForCausalLM.from_config(config)
model.load_state_dict(hf_state_dict, strict=True)
model.save_pretrained(args.out_file)
model.save_pretrained(args.output_path)
hf_tokenizer = AutoTokenizer.from_pretrained('bigcode/starcoder2-tokenizer')
hf_tokenizer.save_pretrained(args.output_path)
logging.info(f'HF checkpoint saved to: {args.output_path}')

0 comments on commit 136aeee

Please sign in to comment.