Skip to content

Commit

Permalink
ggml : add ggml_row_size() (fixes llama out of space) (ggerganov#4461)
Browse files Browse the repository at this point in the history
* Fixes "Not enough space in the context's memory pool" encountered on certain models, which seems to be caused by some imprecision related to the automatic casting of floating point values

* do not cast to size_t, instead just use doubles

* ggml : add ggml_row_size(), deprecate ggml_type_sizef()

* ggml : fix row size compute to avoid overflows

* tests : fix sizey -> sizez

---------

Co-authored-by: Georgi Gerganov <[email protected]>
  • Loading branch information
2 people authored and teleprint-me committed Dec 21, 2023
1 parent 2712de8 commit 5857dcb
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 18 deletions.
14 changes: 7 additions & 7 deletions examples/benchmark/benchmark-matmult.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,13 +129,13 @@ int main(int argc, char ** argv) {
const ggml_type qtype = GGML_TYPE_Q4_1;

size_t ctx_size = 0;
ctx_size += sizex*sizey*ggml_type_sizef(GGML_TYPE_F32);
ctx_size += sizex*sizey*ggml_type_sizef(GGML_TYPE_F32);
ctx_size += sizex*sizez*ggml_type_sizef(GGML_TYPE_F32);
ctx_size += sizex*sizey*ggml_type_sizef(qtype);
ctx_size += sizex*sizey*ggml_type_sizef(qtype);
ctx_size += sizex*sizey*ggml_type_sizef(GGML_TYPE_F32); // BLAS
ctx_size += sizex*sizey*ggml_type_sizef(GGML_TYPE_F32); // BLAS
ctx_size += ggml_row_size(GGML_TYPE_F32, sizex*sizey);
ctx_size += ggml_row_size(GGML_TYPE_F32, sizex*sizey);
ctx_size += ggml_row_size(GGML_TYPE_F32, sizex*sizez);
ctx_size += ggml_row_size(qtype, sizex*sizey);
ctx_size += ggml_row_size(qtype, sizex*sizey);
ctx_size += ggml_row_size(GGML_TYPE_F32, sizex*sizey); // BLAS
ctx_size += ggml_row_size(GGML_TYPE_F32, sizex*sizey); // BLAS
ctx_size += 1024*1024*16;

printf("Allocating Memory of size %zi bytes, %zi MB\n",ctx_size, (ctx_size/1024/1024));
Expand Down
9 changes: 7 additions & 2 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -2013,8 +2013,13 @@ size_t ggml_type_size(enum ggml_type type) {
return type_traits[type].type_size;
}

float ggml_type_sizef(enum ggml_type type) {
return ((float)(type_traits[type].type_size))/type_traits[type].blck_size;
size_t ggml_row_size(enum ggml_type type, int64_t ne) {
assert(ne % ggml_blck_size(type) == 0);
return ggml_type_size(type)*ne/ggml_blck_size(type);
}

double ggml_type_sizef(enum ggml_type type) {
return ((double)(type_traits[type].type_size))/type_traits[type].blck_size;
}

const char * ggml_type_name(enum ggml_type type) {
Expand Down
10 changes: 7 additions & 3 deletions ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -641,9 +641,13 @@ extern "C" {
GGML_API size_t ggml_nbytes_pad (const struct ggml_tensor * tensor); // same as ggml_nbytes() but padded to GGML_MEM_ALIGN
GGML_API size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split);

GGML_API int ggml_blck_size (enum ggml_type type);
GGML_API size_t ggml_type_size (enum ggml_type type); // size in bytes for all elements in a block
GGML_API float ggml_type_sizef(enum ggml_type type); // ggml_type_size()/ggml_blck_size() as float
GGML_API int ggml_blck_size(enum ggml_type type);
GGML_API size_t ggml_type_size(enum ggml_type type); // size in bytes for all elements in a block
GGML_API size_t ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row

GGML_DEPRECATED(
GGML_API double ggml_type_sizef(enum ggml_type type), // ggml_type_size()/ggml_blck_size() as float
"use ggml_row_size() instead");

GGML_API const char * ggml_type_name(enum ggml_type type);
GGML_API const char * ggml_op_name (enum ggml_op op);
Expand Down
12 changes: 6 additions & 6 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1573,7 +1573,7 @@ static bool llama_kv_cache_init(
cache.cells.clear();
cache.cells.resize(n_ctx);

cache.buf.resize(n_elements*(ggml_type_sizef(ktype) + ggml_type_sizef(vtype)) + 2u*n_layer*ggml_tensor_overhead());
cache.buf.resize(ggml_row_size(ktype, n_elements) + ggml_row_size(vtype, n_elements) + 2u*n_layer*ggml_tensor_overhead());
memset(cache.buf.data, 0, cache.buf.size);

struct ggml_init_params params;
Expand Down Expand Up @@ -3856,8 +3856,8 @@ static void llm_build_k_shift(
ggml_rope_custom_inplace(ctx,
ggml_view_3d(ctx, kv.k_l[il],
n_embd_head, n_head_kv, n_ctx,
ggml_type_sizef(kv.k_l[il]->type)*n_embd_head,
ggml_type_sizef(kv.k_l[il]->type)*n_embd_gqa,
ggml_row_size(kv.k_l[il]->type, n_embd_head),
ggml_row_size(kv.k_l[il]->type, n_embd_gqa),
0),
K_shift, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
Expand Down Expand Up @@ -3886,7 +3886,7 @@ static void llm_build_kv_store(
cb(v_cur_t, "v_cur_t", il);

struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_gqa,
(ggml_type_sizef(kv.k_l[il]->type)*n_embd_gqa)*kv_head);
(ggml_row_size(kv.k_l[il]->type, n_embd_gqa))*kv_head);
cb(k_cache_view, "k_cache_view", il);

struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_gqa,
Expand Down Expand Up @@ -4045,8 +4045,8 @@ static struct ggml_tensor * llm_build_kqv(
struct ggml_tensor * k =
ggml_view_3d(ctx, kv.k_l[il],
n_embd_head, n_kv, n_head_kv,
ggml_type_sizef(kv.k_l[il]->type)*n_embd_gqa,
ggml_type_sizef(kv.k_l[il]->type)*n_embd_head,
ggml_row_size(kv.k_l[il]->type, n_embd_gqa),
ggml_row_size(kv.k_l[il]->type, n_embd_head),
0);
cb(k, "k", il);

Expand Down

0 comments on commit 5857dcb

Please sign in to comment.