Support for Phi-2 #4490

Merged · 17 commits · Dec 18, 2023
Changes from 13 commits
22 changes: 22 additions & 0 deletions convert-hf-to-gguf.py
@@ -182,6 +182,8 @@ def from_model_architecture(model_architecture):
return QwenModel
if model_architecture == "MixtralForCausalLM":
return MixtralModel
if model_architecture == "PhiForCausalLM":
return Phi2Model
return Model

def _is_model_safetensors(self) -> bool:
@@ -221,6 +223,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
return gguf.MODEL_ARCH.QWEN
if arch == "MixtralForCausalLM":
return gguf.MODEL_ARCH.LLAMA
if arch == "PhiForCausalLM":
return gguf.MODEL_ARCH.PHI2

raise NotImplementedError(f'Architecture "{arch}" not supported!')

@@ -980,6 +984,24 @@ def write_tensors(self):
print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
self.gguf_writer.add_tensor(new_name, data)


class Phi2Model(Model):
def set_gguf_parameters(self):
block_count = self.hparams["n_layer"]

self.gguf_writer.add_name("Phi2")
self.gguf_writer.add_context_length(self.hparams["n_positions"])
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_head_count(self.hparams["n_head"])
self.gguf_writer.add_head_count_kv(self.hparams["n_head"])
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
self.gguf_writer.add_rope_dimension_count(self.hparams["rotary_dim"])
self.gguf_writer.add_file_type(self.ftype)
self.gguf_writer.add_add_bos_token(False)


###### CONVERSION LOGIC ######


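For reference only (not part of this PR): a minimal C sketch that reads back the metadata Phi2Model.set_gguf_parameters writes, using the gguf API declared in ggml.h. The key names assume the standard "{arch}.*" GGUF naming with arch set to "phi2", and "phi-2.gguf" is a placeholder path.

```c
// sketch: dump a few of the Phi-2 GGUF keys written by the converter above;
// key names assume the standard "{arch}.*" layout with arch == "phi2",
// and "phi-2.gguf" is a placeholder path
#include <stdio.h>
#include "ggml.h"

int main(void) {
    struct gguf_init_params params = { /*.no_alloc =*/ true, /*.ctx =*/ NULL };
    struct gguf_context * gctx = gguf_init_from_file("phi-2.gguf", params);
    if (!gctx) {
        fprintf(stderr, "failed to load phi-2.gguf\n");
        return 1;
    }

    const char * keys[] = {
        "phi2.context_length",       // add_context_length       (n_positions)
        "phi2.embedding_length",     // add_embedding_length     (n_embd)
        "phi2.block_count",          // add_block_count          (n_layer)
        "phi2.attention.head_count", // add_head_count           (n_head)
        "phi2.rope.dimension_count", // add_rope_dimension_count (rotary_dim)
    };

    for (size_t i = 0; i < sizeof(keys)/sizeof(keys[0]); ++i) {
        const int id = gguf_find_key(gctx, keys[i]);
        if (id >= 0) {
            printf("%-28s = %u\n", keys[i], (unsigned) gguf_get_val_u32(gctx, id));
        }
    }

    gguf_free(gctx);
    return 0;
}
```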
115 changes: 80 additions & 35 deletions ggml-cuda.cu
@@ -4998,7 +4998,16 @@ static __global__ void rope_neox(
const int ib = col / n_dims;
const int ic = col % n_dims;

const int i = row*ncols + ib*n_dims + ic/2;
if (ib > 0) {
const int i = row*ncols + ib*n_dims + ic;

dst[i + 0] = x[i + 0];
dst[i + 1] = x[i + 1];

return;
}

const int i = row*ncols + ib*n_dims + ic/2;
const int i2 = row/p_delta_rows;

float cur_rot = inv_ndims * ic - ib;
@@ -7367,6 +7376,8 @@ inline void ggml_cuda_op_mul_mat_cublas(

const int64_t row_diff = row_high - row_low;

// TODO: handle dst->op_params[0] != GGML_PREC_DEFAULT

int id;
CUDA_CHECK(cudaGetDevice(&id));

@@ -8300,27 +8311,27 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
}

static __global__ void k_compute_batched_ptrs(
const half * src0_as_f16, const half * src1_as_f16, half * dst_f16,
const half * src0_as_f16, const half * src1_as_f16, char * dst,
const void ** ptrs_src, void ** ptrs_dst,
int ne12, int ne13,
int ne23,
int nb02, int nb03,
int nb12, int nb13,
int nb2, int nb3,
int r2, int r3) {
int i13 = blockIdx.x * blockDim.x + threadIdx.x;
int i12 = blockIdx.y * blockDim.y + threadIdx.y;
int64_t ne12, int64_t ne13,
int64_t ne23,
size_t nb02, size_t nb03,
size_t nb12, size_t nb13,
size_t nbd2, size_t nbd3,
int64_t r2, int64_t r3) {
int64_t i13 = blockIdx.x * blockDim.x + threadIdx.x;
int64_t i12 = blockIdx.y * blockDim.y + threadIdx.y;

if (i13 >= ne13 || i12 >= ne12) {
return;
}

int i03 = i13 / r3;
int i02 = i12 / r2;
int64_t i03 = i13 / r3;
int64_t i02 = i12 / r2;

ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2;
ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst_f16 + i12* nb2/2 + i13* nb3/2;
ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
}

static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -8376,7 +8387,41 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);

size_t dst_as = 0;
half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);

half * dst_f16 = nullptr;
char * dst_t = nullptr;

cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
cudaDataType_t cu_data_type = CUDA_R_16F;

// dst strides
size_t nbd2 = dst->nb[2];
size_t nbd3 = dst->nb[3];

const half alpha_f16 = 1.0f;
const half beta_f16 = 0.0f;

const float alpha_f32 = 1.0f;
const float beta_f32 = 0.0f;

const char * alpha = (const char *) &alpha_f16;
const char * beta = (const char *) &beta_f16;

if (dst->op_params[0] == GGML_PREC_DEFAULT) {
dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
dst_t = (char *) dst_f16;

nbd2 /= sizeof(float) / sizeof(half);
nbd3 /= sizeof(float) / sizeof(half);
} else {
dst_t = (char *) dst_ddf;

cu_compute_type = CUBLAS_COMPUTE_32F;
cu_data_type = CUDA_R_32F;

alpha = (const char *) &alpha_f32;
beta = (const char *) &beta_f32;
}

GGML_ASSERT(ne12 % ne02 == 0);
GGML_ASSERT(ne13 % ne03 == 0);
@@ -8385,9 +8430,6 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
const int64_t r2 = ne12/ne02;
const int64_t r3 = ne13/ne03;

const half alpha_f16 = 1.0f;
const half beta_f16 = 0.0f;

#if 0
// use cublasGemmEx
{
@@ -8397,12 +8439,12 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
int i02 = i12 / r2;

CUBLAS_CHECK(
cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
cublasGemmEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
&alpha_f16, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half),
(const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
&beta_f16, ( char *) dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2, CUDA_R_16F, ne01,
CUBLAS_COMPUTE_16F,
alpha, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half),
(const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
beta, ( char *) dst_t + i12*nbd2 + i13*nbd3, cu_data_type, ne01,
cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
}
@@ -8414,11 +8456,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
CUBLAS_CHECK(
cublasGemmStridedBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
&alpha_f16, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half), // strideA
(const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
&beta_f16, ( char *) dst_f16, CUDA_R_16F, ne01, dst->nb[2]/sizeof(float), // strideC
alpha, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half), // strideA
(const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
beta, ( char *) dst_t, cu_data_type, ne01, dst->nb[2]/sizeof(float), // strideC
ne12*ne13,
CUBLAS_COMPUTE_16F,
cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
} else {
// use cublasGemmBatchedEx
@@ -8435,24 +8477,24 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const

dim3 block_dims(ne13, ne12);
k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
src0_as_f16, src1_as_f16, dst_f16,
src0_as_f16, src1_as_f16, dst_t,
ptrs_src, ptrs_dst,
ne12, ne13,
ne23,
nb02, nb03,
nb12, nb13,
dst->nb[2], dst->nb[3],
nbd2, nbd3,
r2, r3);
CUDA_CHECK(cudaGetLastError());

CUBLAS_CHECK(
cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
&alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
(const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
&beta_f16, ( void **) (ptrs_dst + 0*ne23), CUDA_R_16F, ne01,
alpha, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
(const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
beta, ( void **) (ptrs_dst + 0*ne23), cu_data_type, ne01,
ne23,
CUBLAS_COMPUTE_16F,
cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));

if (ptrs_src_s != 0) {
@@ -8464,11 +8506,14 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
}
#endif

const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
if (dst->op_params[0] == GGML_PREC_DEFAULT) {
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);

ggml_cuda_pool_free(dst_f16, dst_as);
}

ggml_cuda_pool_free(src1_as_f16, src1_as);
ggml_cuda_pool_free(dst_f16, dst_as);
}

static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
13 changes: 11 additions & 2 deletions ggml-metal.metal
@@ -1702,8 +1702,9 @@ kernel void kernel_rope(
dst_data[1] = x0*sin_theta + x1*cos_theta;
}
} else {
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
if (ic < n_dims) {
const int64_t ib = 0;

// simplified from `(ib * n_dims + ic) * inv_ndims`
const float cur_rot = inv_ndims*ic - ib;
Expand All @@ -1722,6 +1723,14 @@ kernel void kernel_rope(

dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
} else {
const int64_t i0 = ic;

device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
46 changes: 40 additions & 6 deletions ggml.c
@@ -4098,6 +4098,14 @@ struct ggml_tensor * ggml_mul_mat(
return result;
}

void ggml_mul_mat_set_prec(
struct ggml_tensor * a,
enum ggml_prec prec) {
const int32_t prec_i32 = (int32_t) prec;

ggml_set_op_params_i32(a, 0, prec_i32);
}

// ggml_mul_mat_id

struct ggml_tensor * ggml_mul_mat_id(
@@ -9168,6 +9176,8 @@ static void ggml_compute_forward_norm_f32(
float eps;
memcpy(&eps, dst->op_params, sizeof(float));

GGML_ASSERT(eps > 0.0f);

// TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
@@ -9237,6 +9247,8 @@ static void ggml_compute_forward_rms_norm_f32(
float eps;
memcpy(&eps, dst->op_params, sizeof(float));

GGML_ASSERT(eps > 0.0f);

// TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
@@ -11562,10 +11574,13 @@ static void ggml_compute_forward_rope_f32(
}
} else {
// TODO: this might be wrong for ne0 != n_dims - need double check
// ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
// it seems we have to rope just the first n_dims elements and do nothing with the rest
// ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
theta_base *= freq_scale;
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
for (int64_t ic = 0; ic < n_dims; ic += 2) {
for (int64_t ic = 0; ic < ne0; ic += 2) {
if (ic < n_dims) {
const int64_t ib = 0;

// simplified from `(ib * n_dims + ic) * inv_ndims`
float cur_rot = inv_ndims * ic - ib;

Expand All @@ -11588,6 +11603,14 @@ static void ggml_compute_forward_rope_f32(

dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
} else {
const int64_t i0 = ic;

const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
@@ -11715,10 +11738,13 @@ static void ggml_compute_forward_rope_f16(
}
} else {
// TODO: this might be wrong for ne0 != n_dims - need double check
// ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
// it seems we have to rope just the first n_dims elements and do nothing with the rest
// ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
theta_base *= freq_scale;
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
for (int64_t ic = 0; ic < n_dims; ic += 2) {
for (int64_t ic = 0; ic < ne0; ic += 2) {
if (ic < n_dims) {
const int64_t ib = 0;

// simplified from `(ib * n_dims + ic) * inv_ndims`
float cur_rot = inv_ndims * ic - ib;

Expand All @@ -11741,6 +11767,14 @@ static void ggml_compute_forward_rope_f16(

dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
} else {
const int64_t i0 = ic;

const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
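To make the new else-branches concrete, here is a standalone C sketch (not from the PR) of the partial NeoX-style RoPE that the rope hunks above implement: only the first n_dims components of each row are rotated, in (i, i + n_dims/2) pairs, and everything past n_dims is copied through unchanged. Frequency scaling, YaRN correction and the full multi-dimensional indexing of ggml_compute_forward_rope_f32 are omitted; ne0 and n_dims are assumed even.

```c
// minimal reference for partial (NeoX-style) RoPE over a single row:
// rotate the first n_dims values, pass the remaining ne0 - n_dims through
#include <math.h>
#include <stdio.h>

static void rope_neox_partial_row(const float * x, float * y,
                                  int ne0, int n_dims, int pos, float freq_base) {
    for (int ic = 0; ic < ne0; ic += 2) {
        if (ic < n_dims) {
            // simplified: theta = pos * freq_base^(-ic/n_dims), no freq_scale/YaRN
            const float theta = pos * powf(freq_base, -(float) ic / n_dims);
            const float cos_t = cosf(theta);
            const float sin_t = sinf(theta);

            const int   i0 = ic/2;              // NeoX pairing: (i0, i0 + n_dims/2)
            const float x0 = x[i0];
            const float x1 = x[i0 + n_dims/2];

            y[i0]            = x0*cos_t - x1*sin_t;
            y[i0 + n_dims/2] = x0*sin_t + x1*cos_t;
        } else {
            y[ic + 0] = x[ic + 0];              // pass-through beyond n_dims
            y[ic + 1] = x[ic + 1];
        }
    }
}

int main(void) {
    float x[8] = {1, 2, 3, 4, 5, 6, 7, 8};
    float y[8];
    rope_neox_partial_row(x, y, 8, 4, 3, 10000.0f); // rotary_dim = 4, head size = 8
    for (int i = 0; i < 8; ++i) printf("%g ", y[i]);
    printf("\n");                                   // last 4 values are unchanged
    return 0;
}
```

With rotary_dim = 4 and a head size of 8, the last four values come out untouched, which is exactly the behaviour the new pass-through branches add for Phi-2's partial rotary embeddings.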
12 changes: 12 additions & 0 deletions ggml.h
@@ -343,6 +343,12 @@ extern "C" {
GGML_TYPE_COUNT,
};

// precision
enum ggml_prec {
GGML_PREC_DEFAULT,
GGML_PREC_F32,
};

enum ggml_backend_type {
GGML_BACKEND_CPU = 0,
GGML_BACKEND_GPU = 10,
Expand Down Expand Up @@ -1057,6 +1063,12 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * b);

// change the precision of a matrix multiplication
// set to GGML_PREC_F32 for higher precision (useful for phi-2)
GGML_API void ggml_mul_mat_set_prec(
struct ggml_tensor * a,
enum ggml_prec prec);

// indirect matrix multiplication
// ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b)
GGML_API struct ggml_tensor * ggml_mul_mat_id(
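A minimal usage sketch (not from the PR) of the new API: the precision is set on the result tensor of a matmul, which is presumably how the Phi-2 graph in llama.cpp uses it for its attention scores. The function and tensor names below are illustrative.

```c
// sketch: opting a single matmul into F32 precision via ggml_mul_mat_set_prec;
// build_kq and the tensor names are illustrative, not llama.cpp code
#include "ggml.h"

static struct ggml_tensor * build_kq(struct ggml_context * ctx,
                                     struct ggml_tensor  * k,
                                     struct ggml_tensor  * q) {
    struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);

    // by default the batched cuBLAS path above computes and stores this product
    // in F16; GGML_PREC_F32 switches it to F32 compute and output, which the
    // ggml.h comment notes is useful for phi-2
    ggml_mul_mat_set_prec(kq, GGML_PREC_F32);

    return kq;
}
```

The CUDA changes above read this setting back from dst->op_params[0] to select CUBLAS_COMPUTE_32F / CUDA_R_32F and to skip the F16-to-F32 conversion pass.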