Skip to content

Commit

Permalink
Add support for GGUF Phi-3 (huggingface#31844)
Browse files Browse the repository at this point in the history
* Update docs for GGUF supported models

* Add tensor mappings and define class GGUFPhi3Converter

* Fix tokenizer

* Working version

* Attempt to fix some CI failures

* Run ruff format

* Add vocab, merges, decoder methods like LlamaConverter

* Resolve conflicts since Qwen2Moe was added to gguf

- I missed one place when resolving conflict
- I also made a mistake with tests_ggml.py and now has been fixed to reflect
its master version.
  • Loading branch information
a8nova authored and dataKim1201 committed Oct 7, 2024
1 parent 0274e22 commit d8eb6af
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/source/en/gguf.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ For now the supported model architectures are the architectures that have been v
- Mistral
- Qwen2
- Qwen2Moe
- Phi3

## Example usage

Expand Down
1 change: 1 addition & 0 deletions src/transformers/convert_slow_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1575,6 +1575,7 @@ def converted(self) -> Tokenizer:
"LlamaTokenizer": LlamaConverter,
"CodeLlamaTokenizer": LlamaConverter,
"GemmaTokenizer": GemmaConvert,
"Phi3Tokenizer": LlamaConverter,
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@
logger = logging.get_logger(__name__)


TOKENIZER_CLASSES = {name: getattr(transformers, name + "Fast") for name in SLOW_TO_FAST_CONVERTERS}
TOKENIZER_CLASSES = {
# Phi3 uses Llama tokenizer
name: getattr(transformers, "LlamaTokenizerFast" if name == "Phi3Tokenizer" else name + "Fast")
for name in SLOW_TO_FAST_CONVERTERS
}


def convert_slow_checkpoint_to_fast(tokenizer_name, checkpoint_name, dump_path, force_download):
Expand Down
101 changes: 101 additions & 0 deletions src/transformers/integrations/ggml.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,19 @@
"output.weight": "lm_head.weight",
"output_norm": "model.norm",
},
"phi3": {
"token_embd": "model.embed_tokens",
"blk": "model.layers",
"ffn_up": "mlp.gate_up_proj",
"ffn_down": "mlp.down_proj",
"ffn_gate": "mlp.gate_up_proj",
"ffn_norm": "post_attention_layernorm",
"attn_norm": "input_layernorm",
"attn_qkv": "self_attn.qkv_proj",
"attn_output": "self_attn.o_proj",
"output.weight": "lm_head.weight",
"output_norm": "model.norm",
},
}


Expand Down Expand Up @@ -156,6 +169,18 @@
"ggml.unknown_token_id": "unk_token_id",
"ggml.padding_token_id": "pad_token_id",
},
"phi3": {
"context_length": "max_position_embeddings",
"block_count": "num_hidden_layers",
"feed_forward_length": "intermediate_size",
"embedding_length": "hidden_size",
"rope.dimension_count": None,
"rope.freq_base": "rope_theta",
"attention.head_count": "num_attention_heads",
"attention.head_count_kv": "num_key_value_heads",
"attention.layer_norm_rms_epsilon": "rms_norm_eps",
"vocab_size": "vocab_size",
},
}

GGUF_TOKENIZER_MAPPING = {
Expand Down Expand Up @@ -390,10 +415,86 @@ def converted(self) -> Tokenizer:
return tokenizer


class GGUFPhi3Converter(LlamaConverter):
def __init__(self, tokenizer_dict):
self.proto = GGUFTokenizerSkeleton(tokenizer_dict)
self.original_tokenizer = self.proto
self.additional_kwargs = {}

def vocab(self, proto):
return list(zip(proto.tokens, proto.scores))

def merges(self, proto):
return proto.merges

def tokenizer(self, proto):
vocab_scores = self.vocab(self.proto)
merges = self.merges(self.proto)
bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)}

tokenizer = Tokenizer(BPE(bpe_vocab, merges))
# add the special tokens from phi3 tokenizer config
tokenizer.add_special_tokens(
[
AddedToken("</s>", rstrip=True, lstrip=False, normalized=False, special=True),
AddedToken("<|endoftext|>", normalized=False, special=True),
AddedToken("<|assistant|>", rstrip=True, normalized=False, special=True),
AddedToken("<|placeholder1|>", rstrip=True, normalized=False, special=True),
AddedToken("<|placeholder2|>", rstrip=True, normalized=False, special=True),
AddedToken("<|placeholder3|>", rstrip=True, normalized=False, special=True),
AddedToken("<|placeholder4|>", rstrip=True, normalized=False, special=True),
AddedToken("<|system|>", rstrip=True, normalized=False, special=True),
AddedToken("<|end|>", rstrip=True, normalized=False, special=True),
AddedToken("<|placeholder5|>", rstrip=True, normalized=False, special=True),
AddedToken("<|placeholder6|>", rstrip=True, normalized=False, special=True),
AddedToken("<|user|>", rstrip=True, normalized=False, special=True),
]
)

self.additional_kwargs["unk_token"] = (
proto.tokens[proto.unk_token_id] if proto.unk_token_id is not None else None
)
self.additional_kwargs["eos_token"] = (
proto.tokens[proto.eos_token_id] if proto.eos_token_id is not None else None
)
self.additional_kwargs["bos_token"] = (
proto.tokens[proto.bos_token_id] if proto.bos_token_id is not None else None
)
self.additional_kwargs["pad_token"] = (
proto.tokens[proto.pad_token_id] if proto.pad_token_id is not None else None
)

return tokenizer

def decoder(self, replacement, add_prefix_space):
sequence = [
decoders.ByteFallback(),
decoders.Fuse(),
decoders.Replace(replacement, " "),
]

if add_prefix_space:
sequence += [decoders.Strip(content=" ", left=1)]
return decoders.Sequence(sequence)

def converted(self) -> Tokenizer:
tokenizer = self.tokenizer(self.proto)

replacement = "▁"
add_prefix_space = True
if hasattr(self.original_tokenizer, "add_prefix_space"):
add_prefix_space = self.original_tokenizer.add_prefix_space

tokenizer.decoder = self.decoder(replacement, add_prefix_space)

return tokenizer


GGUF_TO_FAST_CONVERTERS = {
"llama": GGUFLlamaConverter,
"qwen2": GGUFQwen2Converter,
"qwen2_moe": GGUFQwen2Converter,
"phi3": GGUFPhi3Converter,
}


Expand Down
14 changes: 14 additions & 0 deletions tests/quantization/ggml/test_ggml.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class GgufIntegrationTests(unittest.TestCase):
qwen2_moe_model_id = "RichardErkhov/Qwen_-_Qwen1.5-MoE-A2.7B-Chat-gguf"
llama3_model_id = "NousResearch/Meta-Llama-3-8B-GGUF"
tinyllama_model_id = "PenutChen/TinyLlama-1.1B-Chat-v1.0-GGUF"
phi3_model_id = "microsoft/Phi-3-mini-4k-instruct-gguf"

# standard quants
q4_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
Expand All @@ -63,6 +64,7 @@ class GgufIntegrationTests(unittest.TestCase):
iq4_xs_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ4_XS.gguf"
iq4_nl_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ4_NL.gguf"

q4_0_phi3_model_id = "Phi-3-mini-4k-instruct-q4.gguf"
q4_0_mistral_model_id = "mistral-7b-instruct-v0.2.Q4_0.gguf"
q4_0_qwen2_model_id = "qwen1_5-0_5b-chat-q4_0.gguf"
q4_0_qwen2_moe_model_id = "Qwen1.5-MoE-A2.7B-Chat.Q4_0.gguf"
Expand Down Expand Up @@ -347,6 +349,18 @@ def test_qwen2_moe_q4_0(self):
EXPECTED_TEXT = "Hello everyone, I'm a newbie here and would like"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

def test_phi3_q4_0(self):
tokenizer = AutoTokenizer.from_pretrained(self.phi3_model_id, gguf_file=self.q4_0_phi3_model_id)
model = AutoModelForCausalLM.from_pretrained(
self.phi3_model_id, gguf_file=self.q4_0_phi3_model_id, device_map="auto", torch_dtype=torch.float16
)

text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)

EXPECTED_TEXT = "Hello, I've been reading about the impact of"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

def test_llama3_q4_0_tokenizer(self):
tokenizer = AutoTokenizer.from_pretrained(self.llama3_model_id, gguf_file=self.q4_llama3_model_id)
with tempfile.TemporaryDirectory() as tmpdirname:
Expand Down

0 comments on commit d8eb6af

Please sign in to comment.