Skip to content

Commit

Permalink
ggml : change ggml_scale to take a float instead of tensor (ggerganov…
Browse files Browse the repository at this point in the history
…#4573)

* ggml : change ggml_scale to take a float instead of tensor

* ggml : fix CPU implementation

* tests : fix test-grad0

ggml-ci
  • Loading branch information
ggerganov authored Dec 21, 2023
1 parent 769a7bc commit afefa31
Show file tree
Hide file tree
Showing 12 changed files with 81 additions and 204 deletions.
15 changes: 3 additions & 12 deletions examples/baby-llama/baby-llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -575,10 +575,7 @@ static struct ggml_tensor * forward(

// KQ_scaled = KQ / sqrt(n_embd/n_head)
// KQ_scaled shape [n_past + N, N, n_head, 1]
struct ggml_tensor * KQ_scaled =
ggml_scale(ctx0,
KQ,
ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, 1.0f/sqrtf(float(n_embd)/n_head));

// KQ_masked = mask_past(KQ_scaled)
// KQ_masked shape [n_past + N, N, n_head, 1]
Expand Down Expand Up @@ -844,10 +841,7 @@ static struct ggml_tensor * forward_batch(

// KQ_scaled = KQ / sqrt(n_embd/n_head)
// KQ_scaled shape [n_past + N, N, n_head, n_batch]
struct ggml_tensor * KQ_scaled =
ggml_scale(ctx0,
KQ,
ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, 1.0f/sqrtf(float(n_embd)/n_head));
assert_shape_4d(KQ_scaled, n_past + N, N, n_head, n_batch);

// KQ_masked = mask_past(KQ_scaled)
Expand Down Expand Up @@ -1131,10 +1125,7 @@ static struct ggml_tensor * forward_lora(

// KQ_scaled = KQ / sqrt(n_embd/n_head)
// KQ_scaled shape [n_past + N, N, n_head, 1]
struct ggml_tensor * KQ_scaled =
ggml_scale(ctx0,
KQ,
ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, 1.0f/sqrtf(float(n_embd)/n_head));

// KQ_masked = mask_past(KQ_scaled)
// KQ_masked shape [n_past + N, N, n_head, 1]
Expand Down
2 changes: 1 addition & 1 deletion examples/export-lora/export-lora.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ static struct ggml_cgraph * build_graph_lora(
) {
struct ggml_tensor * ab = ggml_mul_mat(ctx, lora_a, lora_b);
if (scaling != 1.0f) {
ab = ggml_scale(ctx, ab, ggml_new_f32(ctx, scaling));
ab = ggml_scale(ctx, ab, scaling);
}
struct ggml_tensor * res = ggml_add_inplace(ctx, tensor, ab);

Expand Down
42 changes: 20 additions & 22 deletions examples/finetune/finetune.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ static void load_model_hparams_gguf(struct gguf_context * ctx, struct my_llama_h
float rope_freq_scale = 1.0f;
GGUF_GET_KEY(ctx, hparams->f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
GGUF_GET_KEY(ctx, hparams->rope_freq_base, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE));
GGUF_GET_KEY(ctx, rope_freq_scale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
GGUF_GET_KEY(ctx, rope_freq_scale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
if (rope_freq_scale != 1.0f) {
hparams->rope_freq_scale = 1.0f / rope_freq_scale;
}
Expand Down Expand Up @@ -612,6 +612,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
const int n_rot = hparams.n_embd_head();
const int n_embd_head = hparams.n_embd_head();
const int n_embd_gqa = hparams.n_embd_gqa();

const float rms_norm_eps = hparams.f_norm_rms_eps;
const float rope_freq_base = hparams.rope_freq_base;
const float rope_freq_scale = hparams.rope_freq_scale;
Expand Down Expand Up @@ -680,10 +681,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
checkpoints.push_back(t01);
}

struct ggml_tensor * kv_scale = NULL;
if (!enable_flash_attn) {
kv_scale = ggml_new_f32(ctx, 1.0f/sqrtf(float(n_embd)/n_head));
}
const float kv_scale = 1.0f/sqrtf(float(n_embd)/n_head);

for (int il = 0; il < n_layer; ++il) {
struct my_llama_layer & layer = model->layers[il];
Expand Down Expand Up @@ -781,32 +779,32 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
// make sure some tensors are not reallocated by inserting new temporary nodes depending on them
int n_leafs_before = gb->n_leafs;
int n_nodes_before = gb->n_nodes;
struct ggml_tensor * one = ggml_new_f32(ctx, 1.0f);

// output tensors
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, 1.0f));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, 1.0f));
// input gradient
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, 1.0f));
GGML_ASSERT(t36->grad->data == NULL && t36->grad->view_src == NULL);
ggml_allocr_alloc(alloc, t36->grad);
// KQ_pos
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, 1.0f));

// make sure base model tensors data cannot be used in viewable operations
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->tok_embeddings, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->norm, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->output, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->tok_embeddings, 1.0f));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->norm, 1.0f));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->output, 1.0f));
for (int il = 0; il < n_layer; ++il) {
struct my_llama_layer & layer = model->layers[il];
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.attention_norm, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.ffn_norm, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wq, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wk, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wv, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wo, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w1, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w2, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w3, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.attention_norm, 1.0f));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.ffn_norm, 1.0f));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wq, 1.0f));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wk, 1.0f));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wv, 1.0f));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wo, 1.0f));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w1, 1.0f));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w2, 1.0f));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w3, 1.0f));
}

// allocating checkpoints in one block to reduce memory fragmentation
Expand Down
8 changes: 1 addition & 7 deletions examples/llava/clip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -330,12 +330,6 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
ggml_repeat(ctx0, model.pre_ln_b, embeddings));
}

struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
ggml_allocr_alloc(ctx->alloc, KQ_scale);
if (!ggml_allocr_is_measure(ctx->alloc)) {
ggml_set_f32(KQ_scale, 1.0f / sqrt((float)d_head));
}

// loop over layers
for (int il = 0; il < n_layer - 1; il++) {
struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
Expand All @@ -356,7 +350,7 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
struct ggml_tensor * Q =
ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].q_b, cur), ggml_mul_mat(ctx0, model.layers[il].q_w, cur));

Q = ggml_scale_inplace(ctx0, Q, KQ_scale);
Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size);
Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size);
Expand Down
14 changes: 5 additions & 9 deletions examples/train-text-from-scratch/train-text-from-scratch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -369,10 +369,7 @@ static struct ggml_tensor * llama_build_train_graphs(
checkpoints.push_back(t00);
checkpoints.push_back(t01);

struct ggml_tensor * kv_scale = NULL;
if (!enable_flash_attn) {
kv_scale = ggml_new_f32(ctx, 1.0f/sqrtf(float(n_embd)/n_head));
}
const float kv_scale = 1.0f/sqrtf(float(n_embd)/n_head);

for (int il = 0; il < n_layer; ++il) {
struct my_llama_layer & layer = model->layers[il];
Expand Down Expand Up @@ -444,14 +441,13 @@ static struct ggml_tensor * llama_build_train_graphs(
// make sure some tensors are not reallocated by inserting new temporary nodes depending on them
int n_leafs_before = gb->n_leafs;
int n_nodes_before = gb->n_nodes;
struct ggml_tensor * one = ggml_new_f32(ctx, 1.0f);
// output tensors
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, 1.0f));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, 1.0f));
// input gradient
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, 1.0f));
// KQ_pos
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, 1.0f));
GGML_ASSERT(t36->grad->data == NULL && t36->grad->view_src == NULL);

ggml_allocr_alloc(alloc, t36->grad);
Expand Down
14 changes: 2 additions & 12 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -7700,17 +7700,9 @@ inline void ggml_cuda_op_scale(
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);

float scale;
// HACK: support for ggml backend interface
if (src1->backend == GGML_BACKEND_CPU) {
scale = ((float *) src1->data)[0];
} else {
// TODO: pass pointer to kernel instead of copying to host
CUDA_CHECK(cudaMemcpy(&scale, src1->data, sizeof(float), cudaMemcpyDeviceToHost));
}
const float scale = ((float *) dst->op_params)[0];

scale_f32_cuda(src0_dd, dst_dd, scale, ggml_nelements(src0), main_stream);
CUDA_CHECK(cudaGetLastError());
Expand Down Expand Up @@ -7757,8 +7749,6 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
const bool src1_on_device = use_src1 && src1->backend == GGML_BACKEND_GPU;
const bool dst_on_device = dst->backend == GGML_BACKEND_GPU;

const bool src1_stays_on_host = use_src1 && dst->op == GGML_OP_SCALE;

// dd = data device
float * src0_ddf = nullptr;
float * src1_ddf = nullptr;
Expand All @@ -7779,7 +7769,7 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf, src0, 0, 0, 0, nrows0, main_stream));
}

if (use_src1 && !src1_stays_on_host) {
if (use_src1) {
if (src1_on_device) {
src1_ddf = (float *) src1_extra->data_device[g_main_device];
} else {
Expand Down
6 changes: 3 additions & 3 deletions ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -1293,7 +1293,7 @@ void ggml_metal_graph_compute(
{
GGML_ASSERT(ggml_is_contiguous(src0));

const float scale = *(const float *) src1->data;
const float scale = *(const float *) dst->op_params;

int64_t n = ggml_nelements(dst);

Expand All @@ -1304,8 +1304,8 @@ void ggml_metal_graph_compute(
[encoder setComputePipelineState:ctx->pipeline_scale];
}

[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];

[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
Expand Down
42 changes: 17 additions & 25 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -4171,39 +4171,39 @@ struct ggml_tensor * ggml_out_prod(
static struct ggml_tensor * ggml_scale_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
float s,
bool inplace) {
GGML_ASSERT(ggml_is_scalar(b));
GGML_ASSERT(ggml_is_padded_1d(a));

bool is_node = false;

if (a->grad || b->grad) {
if (a->grad) {
is_node = true;
}

struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);

ggml_set_op_params(result, &s, sizeof(s));

result->op = GGML_OP_SCALE;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;

return result;
}

struct ggml_tensor * ggml_scale(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b) {
return ggml_scale_impl(ctx, a, b, false);
float s) {
return ggml_scale_impl(ctx, a, s, false);
}

struct ggml_tensor * ggml_scale_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b) {
return ggml_scale_impl(ctx, a, b, true);
float s) {
return ggml_scale_impl(ctx, a, s, true);
}

// ggml_set
Expand Down Expand Up @@ -10325,19 +10325,17 @@ static void ggml_compute_forward_out_prod(
static void ggml_compute_forward_scale_f32(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(ggml_are_same_shape(src0, dst));
GGML_ASSERT(ggml_is_scalar(src1));

if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
}

// scale factor
const float v = *(float *) src1->data;
const float v = *(float *) dst->op_params;

const int ith = params->ith;
const int nth = params->nth;
Expand Down Expand Up @@ -10368,12 +10366,11 @@ static void ggml_compute_forward_scale_f32(
static void ggml_compute_forward_scale(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_scale_f32(params, src0, src1, dst);
ggml_compute_forward_scale_f32(params, src0, dst);
} break;
default:
{
Expand Down Expand Up @@ -14383,7 +14380,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
} break;
case GGML_OP_SCALE:
{
ggml_compute_forward_scale(params, tensor->src[0], tensor->src[1], tensor);
ggml_compute_forward_scale(params, tensor->src[0], tensor);
} break;
case GGML_OP_SET:
{
Expand Down Expand Up @@ -14839,7 +14836,7 @@ static struct ggml_tensor * ggml_add_or_set(struct ggml_context * ctx, struct gg

static struct ggml_tensor * ggml_acc_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, struct ggml_hash_set zero_table) {
if (ggml_hash_contains(zero_table, a)) {
struct ggml_tensor * a_zero = ggml_scale(ctx, a, ggml_new_f32(ctx, 0));
struct ggml_tensor * a_zero = ggml_scale(ctx, a, 0.0f);
return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false);
} else {
return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
Expand Down Expand Up @@ -14975,7 +14972,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
src0->grad,
ggml_scale(ctx,
ggml_mul(ctx, src0, tensor->grad),
ggml_new_f32(ctx, 2.0f)),
2.0f),
zero_table);
}
} break;
Expand All @@ -14989,7 +14986,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
ggml_div(ctx,
tensor->grad,
tensor),
ggml_new_f32(ctx, 0.5f)),
0.5f),
zero_table);
}
} break;
Expand Down Expand Up @@ -15155,17 +15152,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
// necessary for llama
if (src0->grad) {
const float s = ((float *) tensor->op_params)[0];

src0->grad =
ggml_add_or_set(ctx,
src0->grad,
ggml_scale_impl(ctx, tensor->grad, src1, false),
zero_table);
}
if (src1->grad) {
src1->grad =
ggml_add_or_set(ctx,
src1->grad,
ggml_sum(ctx, ggml_mul_impl(ctx, tensor->grad, src0, false)),
ggml_scale_impl(ctx, tensor->grad, s, false),
zero_table);
}
} break;
Expand Down
4 changes: 2 additions & 2 deletions ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -1094,13 +1094,13 @@ extern "C" {
GGML_API struct ggml_tensor * ggml_scale(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
float s);

// in-place, returns view(a)
GGML_API struct ggml_tensor * ggml_scale_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
float s);

// b -> view(a,offset,nb1,nb2,3), return modified a
GGML_API struct ggml_tensor * ggml_set(
Expand Down
Loading

0 comments on commit afefa31

Please sign in to comment.