Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for Phi-2 #4490

Merged
merged 17 commits into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion convert-hf-to-gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ def from_model_architecture(model_architecture):
return QwenModel
if model_architecture == "MixtralForCausalLM":
return MixtralModel
if model_architecture == "PhiForCausalLM":
return Phi2Model
return Model

def _is_model_safetensors(self) -> bool:
Expand Down Expand Up @@ -221,6 +223,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
return gguf.MODEL_ARCH.QWEN
if arch == "MixtralForCausalLM":
return gguf.MODEL_ARCH.LLAMA
if arch == "PhiForCausalLM":
return gguf.MODEL_ARCH.PHI2

raise NotImplementedError(f'Architecture "{arch}" not supported!')

Expand All @@ -245,7 +249,7 @@ def _set_vocab_gpt2(self):
toktypes.append(gguf.TokenType.USER_DEFINED)
elif reverse_vocab[i] in added_vocab:
tokens.append(reverse_vocab[i])
if tokenizer.added_tokens_decoder[i].special:
if hasattr(tokenizer, "added_tokens_decoder") and tokenizer.added_tokens_decoder[i].special:
toktypes.append(gguf.TokenType.CONTROL)
else:
toktypes.append(gguf.TokenType.USER_DEFINED)
Expand Down Expand Up @@ -980,6 +984,24 @@ def write_tensors(self):
print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
self.gguf_writer.add_tensor(new_name, data)


class Phi2Model(Model):
def set_gguf_parameters(self):
block_count = self.hparams["n_layer"]

self.gguf_writer.add_name("Phi2")
self.gguf_writer.add_context_length(self.hparams["n_positions"])
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_head_count(self.hparams["n_head"])
self.gguf_writer.add_head_count_kv(self.hparams["n_head"])
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
self.gguf_writer.add_rope_dimension_count(self.hparams["rotary_dim"])
self.gguf_writer.add_file_type(self.ftype)
self.gguf_writer.add_add_bos_token(False)


###### CONVERSION LOGIC ######


Expand Down
11 changes: 10 additions & 1 deletion ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4998,7 +4998,16 @@ static __global__ void rope_neox(
const int ib = col / n_dims;
const int ic = col % n_dims;

const int i = row*ncols + ib*n_dims + ic/2;
// IMPORTANT: consider the case ncols == 80 and n_dims == 32 (phi-2)
// I don't know what we are supposed to compute, because the row is not divisible by n_dims
// this check matches the CPU code, but it is likely wrong as well
// I can't understand the Python code, so if you know what to do here, please fix it
// ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
if (ncols % n_dims != 0 && ib == ncols/n_dims) {
return;
}
ggerganov marked this conversation as resolved.
Show resolved Hide resolved

const int i = row*ncols + ib*n_dims + ic/2;
const int i2 = row/p_delta_rows;

float cur_rot = inv_ndims * ic - ib;
Expand Down
4 changes: 4 additions & 0 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -9168,6 +9168,8 @@ static void ggml_compute_forward_norm_f32(
float eps;
memcpy(&eps, dst->op_params, sizeof(float));

GGML_ASSERT(eps > 0.0f);

// TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
Expand Down Expand Up @@ -9237,6 +9239,8 @@ static void ggml_compute_forward_rms_norm_f32(
float eps;
memcpy(&eps, dst->op_params, sizeof(float));

GGML_ASSERT(eps > 0.0f);

// TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
Expand Down
13 changes: 13 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ class MODEL_ARCH(IntEnum):
BLOOM = auto()
STABLELM = auto()
QWEN = auto()
PHI2 = auto()


class MODEL_TENSOR(IntEnum):
Expand Down Expand Up @@ -140,6 +141,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.BLOOM: "bloom",
MODEL_ARCH.STABLELM: "stablelm",
MODEL_ARCH.QWEN: "qwen",
MODEL_ARCH.PHI2: "phi2",
}

TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
Expand Down Expand Up @@ -350,6 +352,17 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.GPT2: [
# TODO
],
MODEL_ARCH.PHI2: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
]
# TODO
}

Expand Down
8 changes: 8 additions & 0 deletions gguf-py/gguf/tensor_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class TensorNameMap:
"tok_embeddings", # llama-pth
"embeddings.word_embeddings", # bert
"language_model.embedding.word_embeddings", # persimmon
"transformer.embd.wte", # phi2
),

# Token type embeddings
Expand All @@ -41,6 +42,7 @@ class TensorNameMap:
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen
"output", # llama-pth bloom
"word_embeddings_for_head", # persimmon
"lm_head.linear", # phi2
),

# Output norm
Expand All @@ -53,6 +55,7 @@ class TensorNameMap:
"transformer.norm_f", # mpt
"ln_f", # refact bloom qwen
"language_model.encoder.final_layernorm", # persimmon
"lm_head.ln", # phi2
),

# Rope frequencies
Expand All @@ -75,6 +78,7 @@ class TensorNameMap:
"encoder.layer.{bid}.attention.output.LayerNorm", # bert
"language_model.encoder.layers.{bid}.input_layernorm", # persimmon
"model.layers.{bid}.ln1", # yi
"transformer.h.{bid}.ln", # phi2
),

# Attention norm 2
Expand All @@ -90,6 +94,7 @@ class TensorNameMap:
"transformer.h.{bid}.self_attention.query_key_value", # falcon
"h.{bid}.self_attention.query_key_value", # bloom
"language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon
"transformer.h.{bid}.mixer.Wqkv", # phi2
),

# Attention query
Expand Down Expand Up @@ -128,6 +133,7 @@ class TensorNameMap:
"encoder.layer.{bid}.attention.output.dense", # bert
"transformer.h.{bid}.attn.out_proj", # gpt-j
"language_model.encoder.layers.{bid}.self_attention.dense", # persimmon
"transformer.h.{bid}.mixer.out_proj", # phi2
),

# Rotary embeddings
Expand Down Expand Up @@ -167,6 +173,7 @@ class TensorNameMap:
"transformer.h.{bid}.mlp.fc_in", # gpt-j
"language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon
"transformer.h.{bid}.mlp.w1", # qwen
"transformer.h.{bid}.mlp.fc1", # phi2
),

MODEL_TENSOR.FFN_UP_EXP: (
Expand Down Expand Up @@ -198,6 +205,7 @@ class TensorNameMap:
"encoder.layer.{bid}.output.dense", # bert
"transformer.h.{bid}.mlp.fc_out", # gpt-j
"language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon
"transformer.h.{bid}.mlp.fc2", # phi2
),

MODEL_TENSOR.FFN_DOWN_EXP: (
Expand Down
Loading
Loading