ggml : change ggml_scale to take a float instead of tensor #4573

Merged · 4 commits · Dec 21, 2023
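In short: `ggml_scale` and `ggml_scale_inplace` now take the scale factor as a plain `float`, stored in the node's `op_params`, instead of a 1-element F32 tensor passed as `src1`. A minimal before/after sketch of a typical call site (names such as `ctx0`, `KQ`, `n_embd`, `n_head` are borrowed from the examples below; this is an illustration, not a standalone program):

```c
// Before: the scale had to be materialized as a scalar tensor in the context.
struct ggml_tensor * KQ_scaled =
    ggml_scale(ctx0, KQ, ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));

// After: the scale is passed directly as a float; no extra tensor or
// allocator bookkeeping is needed, and backends read it from op_params.
struct ggml_tensor * KQ_scaled =
    ggml_scale(ctx0, KQ, 1.0f/sqrtf(float(n_embd)/n_head));
```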
15 changes: 3 additions & 12 deletions examples/baby-llama/baby-llama.cpp
@@ -575,10 +575,7 @@ static struct ggml_tensor * forward(

// KQ_scaled = KQ / sqrt(n_embd/n_head)
// KQ_scaled shape [n_past + N, N, n_head, 1]
struct ggml_tensor * KQ_scaled =
ggml_scale(ctx0,
KQ,
ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, 1.0f/sqrtf(float(n_embd)/n_head));

// KQ_masked = mask_past(KQ_scaled)
// KQ_masked shape [n_past + N, N, n_head, 1]
@@ -844,10 +841,7 @@ static struct ggml_tensor * forward_batch(

// KQ_scaled = KQ / sqrt(n_embd/n_head)
// KQ_scaled shape [n_past + N, N, n_head, n_batch]
struct ggml_tensor * KQ_scaled =
ggml_scale(ctx0,
KQ,
ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, 1.0f/sqrtf(float(n_embd)/n_head));
assert_shape_4d(KQ_scaled, n_past + N, N, n_head, n_batch);

// KQ_masked = mask_past(KQ_scaled)
@@ -1131,10 +1125,7 @@ static struct ggml_tensor * forward_lora(

// KQ_scaled = KQ / sqrt(n_embd/n_head)
// KQ_scaled shape [n_past + N, N, n_head, 1]
struct ggml_tensor * KQ_scaled =
ggml_scale(ctx0,
KQ,
ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, 1.0f/sqrtf(float(n_embd)/n_head));

// KQ_masked = mask_past(KQ_scaled)
// KQ_masked shape [n_past + N, N, n_head, 1]
2 changes: 1 addition & 1 deletion examples/export-lora/export-lora.cpp
@@ -309,7 +309,7 @@ static struct ggml_cgraph * build_graph_lora(
) {
struct ggml_tensor * ab = ggml_mul_mat(ctx, lora_a, lora_b);
if (scaling != 1.0f) {
ab = ggml_scale(ctx, ab, ggml_new_f32(ctx, scaling));
ab = ggml_scale(ctx, ab, scaling);
}
struct ggml_tensor * res = ggml_add_inplace(ctx, tensor, ab);

42 changes: 20 additions & 22 deletions examples/finetune/finetune.cpp
@@ -269,7 +269,7 @@ static void load_model_hparams_gguf(struct gguf_context * ctx, struct my_llama_h
float rope_freq_scale = 1.0f;
GGUF_GET_KEY(ctx, hparams->f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
GGUF_GET_KEY(ctx, hparams->rope_freq_base, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE));
GGUF_GET_KEY(ctx, rope_freq_scale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
if (rope_freq_scale != 1.0f) {
hparams->rope_freq_scale = 1.0f / rope_freq_scale;
}
@@ -612,6 +612,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
const int n_rot = hparams.n_embd_head();
const int n_embd_head = hparams.n_embd_head();
const int n_embd_gqa = hparams.n_embd_gqa();

const float rms_norm_eps = hparams.f_norm_rms_eps;
const float rope_freq_base = hparams.rope_freq_base;
const float rope_freq_scale = hparams.rope_freq_scale;
@@ -680,10 +681,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
checkpoints.push_back(t01);
}

struct ggml_tensor * kv_scale = NULL;
if (!enable_flash_attn) {
kv_scale = ggml_new_f32(ctx, 1.0f/sqrtf(float(n_embd)/n_head));
}
const float kv_scale = 1.0f/sqrtf(float(n_embd)/n_head);

for (int il = 0; il < n_layer; ++il) {
struct my_llama_layer & layer = model->layers[il];
@@ -781,32 +779,32 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
// make sure some tensors are not reallocated by inserting new temporary nodes depending on them
int n_leafs_before = gb->n_leafs;
int n_nodes_before = gb->n_nodes;
struct ggml_tensor * one = ggml_new_f32(ctx, 1.0f);

// output tensors
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, 1.0f));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, 1.0f));
// input gradient
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, 1.0f));
GGML_ASSERT(t36->grad->data == NULL && t36->grad->view_src == NULL);
ggml_allocr_alloc(alloc, t36->grad);
// KQ_pos
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, 1.0f));

// make sure base model tensors data cannot be used in viewable operations
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->tok_embeddings, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->norm, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->output, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->tok_embeddings, 1.0f));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->norm, 1.0f));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->output, 1.0f));
for (int il = 0; il < n_layer; ++il) {
struct my_llama_layer & layer = model->layers[il];
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.attention_norm, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.ffn_norm, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wq, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wk, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wv, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wo, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w1, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w2, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w3, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.attention_norm, 1.0f));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.ffn_norm, 1.0f));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wq, 1.0f));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wk, 1.0f));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wv, 1.0f));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wo, 1.0f));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w1, 1.0f));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w2, 1.0f));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w3, 1.0f));
}

// allocating checkpoints in one block to reduce memory fragmentation
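The `ggml_scale_inplace(ctx, t, 1.0f)` calls above are identity scales: they do not change any values, they only add graph nodes that depend on the given tensors so the allocator will not reuse their buffers. With the float-based API the dummy `one` tensor is no longer needed for this trick. A minimal sketch of the pattern, assuming an existing context `ctx`, graph `gb`, and tensor `t` (hypothetical names):

```c
// Identity scale: computes t * 1.0f, leaving the values untouched.
// Its only purpose is to make `t` an input of a fresh node, so the graph
// allocator treats `t` as still in use and does not reallocate its data.
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t, 1.0f));
```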
8 changes: 1 addition & 7 deletions examples/llava/clip.cpp
@@ -330,12 +330,6 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
ggml_repeat(ctx0, model.pre_ln_b, embeddings));
}

struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
ggml_allocr_alloc(ctx->alloc, KQ_scale);
if (!ggml_allocr_is_measure(ctx->alloc)) {
ggml_set_f32(KQ_scale, 1.0f / sqrt((float)d_head));
}

// loop over layers
for (int il = 0; il < n_layer - 1; il++) {
struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
@@ -356,7 +350,7 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
struct ggml_tensor * Q =
ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].q_b, cur), ggml_mul_mat(ctx0, model.layers[il].q_w, cur));

Q = ggml_scale_inplace(ctx0, Q, KQ_scale);
Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size);
Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size);
14 changes: 5 additions & 9 deletions examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -369,10 +369,7 @@ static struct ggml_tensor * llama_build_train_graphs(
checkpoints.push_back(t00);
checkpoints.push_back(t01);

struct ggml_tensor * kv_scale = NULL;
if (!enable_flash_attn) {
kv_scale = ggml_new_f32(ctx, 1.0f/sqrtf(float(n_embd)/n_head));
}
const float kv_scale = 1.0f/sqrtf(float(n_embd)/n_head);

for (int il = 0; il < n_layer; ++il) {
struct my_llama_layer & layer = model->layers[il];
@@ -444,14 +441,13 @@ static struct ggml_tensor * llama_build_train_graphs(
// make sure some tensors are not reallocated by inserting new temporary nodes depending on them
int n_leafs_before = gb->n_leafs;
int n_nodes_before = gb->n_nodes;
struct ggml_tensor * one = ggml_new_f32(ctx, 1.0f);
// output tensors
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, 1.0f));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, 1.0f));
// input gradient
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, 1.0f));
// KQ_pos
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, 1.0f));
GGML_ASSERT(t36->grad->data == NULL && t36->grad->view_src == NULL);

ggml_allocr_alloc(alloc, t36->grad);
14 changes: 2 additions & 12 deletions ggml-cuda.cu
@@ -7700,17 +7700,9 @@ inline void ggml_cuda_op_scale(
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);

float scale;
// HACK: support for ggml backend interface
if (src1->backend == GGML_BACKEND_CPU) {
scale = ((float *) src1->data)[0];
} else {
// TODO: pass pointer to kernel instead of copying to host
CUDA_CHECK(cudaMemcpy(&scale, src1->data, sizeof(float), cudaMemcpyDeviceToHost));
}
const float scale = ((float *) dst->op_params)[0];

scale_f32_cuda(src0_dd, dst_dd, scale, ggml_nelements(src0), main_stream);
CUDA_CHECK(cudaGetLastError());
@@ -7757,8 +7749,6 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
const bool src1_on_device = use_src1 && src1->backend == GGML_BACKEND_GPU;
const bool dst_on_device = dst->backend == GGML_BACKEND_GPU;

const bool src1_stays_on_host = use_src1 && dst->op == GGML_OP_SCALE;

// dd = data device
float * src0_ddf = nullptr;
float * src1_ddf = nullptr;
@@ -7779,7 +7769,7 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf, src0, 0, 0, 0, nrows0, main_stream));
}

if (use_src1 && !src1_stays_on_host) {
if (use_src1) {
if (src1_on_device) {
src1_ddf = (float *) src1_extra->data_device[g_main_device];
} else {
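With `src1` gone, backends no longer need the old hack of fetching the scale from a (possibly device-resident) tensor; CUDA in particular drops the `cudaMemcpy` back to host. The value is written into the node's parameter blob at graph-build time and read back at execution time. A sketch of the two ends, mirroring the code in this diff:

```c
// Graph construction (ggml.c, ggml_scale_impl): the float is copied
// into the op's small fixed-size parameter area.
ggml_set_op_params(result, &s, sizeof(s));

// Backend execution (CPU/CUDA/Metal): the float is read back from the node.
const float scale = ((float *) dst->op_params)[0];
```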
6 changes: 3 additions & 3 deletions ggml-metal.m
@@ -1293,7 +1293,7 @@ void ggml_metal_graph_compute(
{
GGML_ASSERT(ggml_is_contiguous(src0));

const float scale = *(const float *) src1->data;
const float scale = *(const float *) dst->op_params;

int64_t n = ggml_nelements(dst);

@@ -1304,8 +1304,8 @@ void ggml_metal_graph_compute(
[encoder setComputePipelineState:ctx->pipeline_scale];
}

[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];

[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
42 changes: 17 additions & 25 deletions ggml.c
@@ -4171,39 +4171,39 @@ struct ggml_tensor * ggml_out_prod(
static struct ggml_tensor * ggml_scale_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
float s,
bool inplace) {
GGML_ASSERT(ggml_is_scalar(b));
GGML_ASSERT(ggml_is_padded_1d(a));

bool is_node = false;

if (a->grad || b->grad) {
if (a->grad) {
is_node = true;
}

struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);

ggml_set_op_params(result, &s, sizeof(s));

result->op = GGML_OP_SCALE;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;

return result;
}

struct ggml_tensor * ggml_scale(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b) {
return ggml_scale_impl(ctx, a, b, false);
float s) {
return ggml_scale_impl(ctx, a, s, false);
}

struct ggml_tensor * ggml_scale_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b) {
return ggml_scale_impl(ctx, a, b, true);
float s) {
return ggml_scale_impl(ctx, a, s, true);
}

// ggml_set
@@ -10325,19 +10325,17 @@ static void ggml_compute_forward_out_prod(
static void ggml_compute_forward_scale_f32(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(ggml_are_same_shape(src0, dst));
GGML_ASSERT(ggml_is_scalar(src1));

if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
}

// scale factor
const float v = *(float *) src1->data;
const float v = *(float *) dst->op_params;

const int ith = params->ith;
const int nth = params->nth;
@@ -10368,12 +10366,11 @@ static void ggml_compute_forward_scale_f32(
static void ggml_compute_forward_scale(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_scale_f32(params, src0, src1, dst);
ggml_compute_forward_scale_f32(params, src0, dst);
} break;
default:
{
@@ -14383,7 +14380,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
} break;
case GGML_OP_SCALE:
{
ggml_compute_forward_scale(params, tensor->src[0], tensor->src[1], tensor);
ggml_compute_forward_scale(params, tensor->src[0], tensor);
} break;
case GGML_OP_SET:
{
@@ -14839,7 +14836,7 @@ static struct ggml_tensor * ggml_add_or_set(struct ggml_context * ctx, struct gg

static struct ggml_tensor * ggml_acc_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, struct ggml_hash_set zero_table) {
if (ggml_hash_contains(zero_table, a)) {
struct ggml_tensor * a_zero = ggml_scale(ctx, a, ggml_new_f32(ctx, 0));
struct ggml_tensor * a_zero = ggml_scale(ctx, a, 0.0f);
return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false);
} else {
return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
@@ -14975,7 +14972,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
src0->grad,
ggml_scale(ctx,
ggml_mul(ctx, src0, tensor->grad),
ggml_new_f32(ctx, 2.0f)),
2.0f),
zero_table);
}
} break;
@@ -14989,7 +14986,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
ggml_div(ctx,
tensor->grad,
tensor),
ggml_new_f32(ctx, 0.5f)),
0.5f),
zero_table);
}
} break;
@@ -15155,17 +15152,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
 {
     // necessary for llama
     if (src0->grad) {
+        const float s = ((float *) tensor->op_params)[0];
+
         src0->grad =
             ggml_add_or_set(ctx,
                 src0->grad,
-                ggml_scale_impl(ctx, tensor->grad, src1, false),
-                zero_table);
-    }
-    if (src1->grad) {
-        src1->grad =
-            ggml_add_or_set(ctx,
-                src1->grad,
-                ggml_sum(ctx, ggml_mul_impl(ctx, tensor->grad, src0, false)),
+                ggml_scale_impl(ctx, tensor->grad, s, false),
                 zero_table);
     }
 } break;
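The backward pass simplifies accordingly: with the scale attached to the node as a constant rather than passed as a second source tensor, only `src0` receives a gradient. For reference (standard chain rule, with the incoming gradient written as the bar term):

```math
y = s\,a \quad\Rightarrow\quad \frac{\partial L}{\partial a} = s\,\bar{y}, \qquad \frac{\partial L}{\partial s} = \sum_i a_i\,\bar{y}_i
```

The second term is the `ggml_sum(ggml_mul(grad, src0))` accumulation removed above; it is no longer needed because `s` is not a tensor and cannot carry a gradient.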
4 changes: 2 additions & 2 deletions ggml.h
@@ -1094,13 +1094,13 @@ extern "C" {
GGML_API struct ggml_tensor * ggml_scale(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
float s);

// in-place, returns view(a)
GGML_API struct ggml_tensor * ggml_scale_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
float s);

// b -> view(a,offset,nb1,nb2,3), return modified a
GGML_API struct ggml_tensor * ggml_set(
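Updated usage of the public API as declared above; `ctx` and `x` are assumed to be a valid `ggml_context *` and an F32 tensor (hypothetical names):

```c
// Out-of-place: returns a new tensor holding x * 0.5f once the graph runs.
struct ggml_tensor * y = ggml_scale(ctx, x, 0.5f);

// In-place: returns view(x); x's own data is scaled when the graph runs.
struct ggml_tensor * z = ggml_scale_inplace(ctx, x, 0.5f);
```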