From 857e7355ff1ba847bef77b12413c49ac4023092f Mon Sep 17 00:00:00 2001 From: slaren Date: Thu, 21 Dec 2023 21:07:46 +0100 Subject: [PATCH] llama : initial ggml-backend integration (#4520) * llama : initial ggml-backend integration * add ggml-metal * cuda backend can be used though ggml-backend with LLAMA_GGML_BACKEND_CUDA_TEST access all tensor data with ggml_backend_tensor_get/set * add ggml_backend_buffer_clear zero-init KV cache buffer * add ggml_backend_buffer_is_hos, used to avoid copies if possible when accesing tensor data * disable gpu backends with ngl 0 * more accurate mlock * unmap offloaded part of the model * use posix_fadvise64(.., POSIX_FADV_SEQUENTIAL) to improve performance with mmap * update quantize and lora * update session copy/set to use ggml-backend ggml-ci * use posix_fadvise instead of posix_fadvise64 * ggml_backend_alloc_ctx_tensors_from_buft : remove old print * llama_mmap::align_offset : use pointers instead of references for out parameters * restore progress_callback behavior * move final progress_callback call to load_all_data * cuda : fix fprintf format string (minor) * do not offload scales * llama_mmap : avoid unmapping the same fragments again in the destructor * remove unnecessary unmap * metal : add default log function that prints to stderr, cleanup code ggml-ci --------- Co-authored-by: Georgi Gerganov --- Makefile | 2 +- ggml-alloc.c | 16 +- ggml-backend-impl.h | 20 +- ggml-backend.c | 80 ++- ggml-backend.h | 7 + ggml-cuda.cu | 87 ++-- ggml-metal.h | 3 + ggml-metal.m | 228 +++++++-- ggml.c | 24 +- ggml.h | 13 +- llama.cpp | 1196 ++++++++++++++++++++----------------------- 11 files changed, 925 insertions(+), 751 deletions(-) diff --git a/Makefile b/Makefile index 8273f84004df69..512407a1de87b9 100644 --- a/Makefile +++ b/Makefile @@ -65,7 +65,7 @@ test: $(TEST_TARGETS) ./$$test_target; \ fi; \ if [ $$? -ne 0 ]; then \ - printf 'Test $$test_target FAILED!\n\n' $$test_target; \ + printf 'Test %s FAILED!\n\n' $$test_target; \ failures=$$(( failures + 1 )); \ else \ printf 'Test %s passed.\n\n' $$test_target; \ diff --git a/ggml-alloc.c b/ggml-alloc.c index d3049efb497a0a..a97436b17ed70d 100644 --- a/ggml-alloc.c +++ b/ggml-alloc.c @@ -449,11 +449,10 @@ static void init_view(ggml_gallocr_t galloc, struct ggml_tensor * view, bool upd if (update_backend) { view->backend = view->view_src->backend; } - view->buffer = view->view_src->buffer; + // views are initialized in the alloc buffer rather than the view_src buffer + view->buffer = alloc->buffer; view->data = (char *)view->view_src->data + view->view_offs; - // FIXME: the view should be initialized by the owning buffer, but currently this breaks the CUDA backend - // due to the ggml_tensor_extra_gpu ring buffer overwriting the KV cache extras assert(ggml_tallocr_is_measure(alloc) || !view->buffer || view->buffer->buft == alloc->buffer->buft); if (!alloc->measure) { @@ -736,6 +735,10 @@ void ggml_allocr_set_parse_seq(ggml_allocr_t alloc, const int * list, int n) { } void ggml_allocr_free(ggml_allocr_t alloc) { + if (alloc == NULL) { + return; + } + ggml_gallocr_free(alloc->galloc); ggml_tallocr_free(alloc->talloc); free(alloc); @@ -775,7 +778,7 @@ ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_conte } if (nbytes == 0) { - fprintf(stderr, "%s: no tensors to allocate\n", __func__); + // all the tensors in the context are already allocated return NULL; } @@ -789,6 +792,11 @@ ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_conte } else { ggml_backend_view_init(buffer, t); } + } else { + if (t->view_src != NULL) { + // view of a pre-allocated tensor + ggml_backend_view_init(buffer, t); + } } } diff --git a/ggml-backend-impl.h b/ggml-backend-impl.h index f588af60282650..05859935a3c2fa 100644 --- a/ggml-backend-impl.h +++ b/ggml-backend-impl.h @@ -20,6 +20,9 @@ extern "C" { size_t (*get_alignment) (ggml_backend_buffer_type_t buft); // tensor alignment size_t (*get_alloc_size) (ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor); // data size needed to allocate the tensor, including padding bool (*supports_backend)(ggml_backend_buffer_type_t buft, ggml_backend_t backend); // check if the buffer type is usable by the backend + // check if tensor data is in host memory + // should be equivalent to supports_backend(buft, ggml_backend_cpu_init()) + bool (*is_host) (ggml_backend_buffer_type_t buft); }; struct ggml_backend_buffer_type { @@ -31,15 +34,16 @@ extern "C" { typedef void * ggml_backend_buffer_context_t; struct ggml_backend_buffer_i { - void (*free_buffer)(ggml_backend_buffer_t buffer); + void (*free_buffer) (ggml_backend_buffer_t buffer); //void (*reset) (ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras - void * (*get_base) (ggml_backend_buffer_t buffer); - void (*init_tensor)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); - void (*set_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); - void (*get_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + void * (*get_base) (ggml_backend_buffer_t buffer); + void (*init_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); + void (*set_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + void (*get_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); // (optional) copy tensor between different buffer-type, allow for single-copy tranfers - void (*cpy_tensor_from)(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst); - void (*cpy_tensor_to) (ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst); + void (*cpy_tensor_from)(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst); + void (*cpy_tensor_to) (ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst); + void (*clear) (ggml_backend_buffer_t buffer, uint8_t value); }; struct ggml_backend_buffer { @@ -78,7 +82,7 @@ extern "C" { void (*cpy_tensor_from_async)(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst); void (*cpy_tensor_to_async) (ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst); - void (*synchronize) (ggml_backend_t backend); + void (*synchronize)(ggml_backend_t backend); // compute graph with a plan ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, struct ggml_cgraph * cgraph); diff --git a/ggml-backend.c b/ggml-backend.c index 3a22cd085eac0a..0c8c9ec430475a 100644 --- a/ggml-backend.c +++ b/ggml-backend.c @@ -35,6 +35,13 @@ bool ggml_backend_buft_supports_backend(ggml_backend_buffer_type_t buft, ggml_ba return buft->iface.supports_backend(buft, backend); } +bool ggml_backend_buft_is_host(ggml_backend_buffer_type_t buft) { + if (buft->iface.is_host) { + return buft->iface.is_host(buft); + } + return false; +} + // backend buffer ggml_backend_buffer_t ggml_backend_buffer_init( @@ -94,6 +101,14 @@ size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct g return ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type(buffer), tensor); } +void ggml_backend_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + buffer->iface.clear(buffer, value); +} + +bool ggml_backend_buffer_is_host(ggml_backend_buffer_t buffer) { + return ggml_backend_buft_is_host(ggml_backend_buffer_type(buffer)); +} + ggml_backend_buffer_type_t ggml_backend_buffer_type(ggml_backend_buffer_t buffer) { return buffer->buft; } @@ -378,7 +393,6 @@ static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) { static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) { free(buffer->context); - GGML_UNUSED(buffer); } static void ggml_backend_cpu_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { @@ -411,6 +425,10 @@ static void ggml_backend_cpu_buffer_cpy_tensor_to(ggml_backend_buffer_t buffer, GGML_UNUSED(buffer); } +static void ggml_backend_cpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + memset(buffer->context, value, buffer->size); +} + static struct ggml_backend_buffer_i cpu_backend_buffer_i = { /* .free_buffer = */ ggml_backend_cpu_buffer_free_buffer, /* .get_base = */ ggml_backend_cpu_buffer_get_base, @@ -419,6 +437,7 @@ static struct ggml_backend_buffer_i cpu_backend_buffer_i = { /* .get_tensor = */ ggml_backend_cpu_buffer_get_tensor, /* .cpy_tensor_from = */ ggml_backend_cpu_buffer_cpy_tensor_from, /* .cpy_tensor_to = */ ggml_backend_cpu_buffer_cpy_tensor_to, + /* .clear = */ ggml_backend_cpu_buffer_clear, }; // for buffers from ptr, free is not called @@ -430,6 +449,7 @@ static struct ggml_backend_buffer_i cpu_backend_buffer_i_from_ptr = { /* .get_tensor = */ ggml_backend_cpu_buffer_get_tensor, /* .cpy_tensor_from = */ ggml_backend_cpu_buffer_cpy_tensor_from, /* .cpy_tensor_to = */ ggml_backend_cpu_buffer_cpy_tensor_to, + /* .clear = */ ggml_backend_cpu_buffer_clear, }; static const size_t TENSOR_ALIGNMENT = 64; // should be enough for AVX 512 @@ -455,20 +475,70 @@ static bool ggml_backend_cpu_buffer_type_supports_backend(ggml_backend_buffer_ty GGML_UNUSED(buft); } +static bool ggml_backend_cpu_buffer_type_is_host(ggml_backend_buffer_type_t buft) { + return true; + + GGML_UNUSED(buft); +} + ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void) { - static struct ggml_backend_buffer_type ggml_backend_buffer_type_cpu = { + static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type = { /* .iface = */ { /* .alloc_buffer = */ ggml_backend_cpu_buffer_type_alloc_buffer, /* .get_alignment = */ ggml_backend_cpu_buffer_type_get_alignment, /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes /* .supports_backend = */ ggml_backend_cpu_buffer_type_supports_backend, + /* .is_host = */ ggml_backend_cpu_buffer_type_is_host, }, /* .context = */ NULL, }; - return &ggml_backend_buffer_type_cpu; + return &ggml_backend_cpu_buffer_type; } +#ifdef GGML_USE_CPU_HBM + +// buffer type HBM + +#include + +static void ggml_backend_cpu_hbm_buffer_free_buffer(ggml_backend_buffer_t buffer) { + hbw_free(buffer->context); +} + +static ggml_backend_buffer_t ggml_backend_cpu_hbm_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + //void * ptr = hbw_malloc(size); + void * ptr; + int result = hbw_posix_memalign(&ptr, ggml_backend_cpu_buffer_type_get_alignment(buft), size); + if (result != 0) { + fprintf(stderr, "failed to allocate HBM buffer of size %zu\n", size); + return NULL; + } + + // FIXME: this is a hack to avoid having to implement a new buffer type + ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size); + buffer->buft = buft; + buffer->iface.free_buffer = ggml_backend_cpu_hbm_buffer_free_buffer; + + return buffer; +} + +ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type() { + static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_hbm = { + /* .iface = */ { + /* .alloc_buffer = */ ggml_backend_cpu_hbm_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_cpu_buffer_type_get_alignment, + /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes + /* .supports_backend = */ ggml_backend_cpu_buffer_type_supports_backend, + /* .is_host = */ ggml_backend_cpu_buffer_type_is_host, + }, + /* .context = */ NULL, + }; + + return &ggml_backend_cpu_buffer_type_hbm; +} +#endif + struct ggml_backend_cpu_context { int n_threads; void * work_data; @@ -505,7 +575,7 @@ static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(ggml_backend struct ggml_backend_plan_cpu * cpu_plan = malloc(sizeof(struct ggml_backend_plan_cpu)); cpu_plan->cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads); - cpu_plan->cgraph = *cgraph; + cpu_plan->cgraph = *cgraph; // FIXME: deep copy if (cpu_plan->cplan.work_size > 0) { cpu_plan->cplan.work_data = malloc(cpu_plan->cplan.work_size); @@ -1180,7 +1250,7 @@ void ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml // utils void ggml_backend_view_init(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { GGML_ASSERT(tensor->buffer == NULL); - GGML_ASSERT(tensor->data == NULL); + //GGML_ASSERT(tensor->data == NULL); // views of pre-allocted tensors may have the data set, but still need to be initialized GGML_ASSERT(tensor->view_src != NULL); GGML_ASSERT(tensor->view_src->buffer != NULL); GGML_ASSERT(tensor->view_src->data != NULL); diff --git a/ggml-backend.h b/ggml-backend.h index 58d5ccae6ed101..a9d2fddd726a85 100644 --- a/ggml-backend.h +++ b/ggml-backend.h @@ -21,6 +21,7 @@ extern "C" { GGML_API size_t ggml_backend_buft_get_alignment (ggml_backend_buffer_type_t buft); GGML_API size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor); GGML_API bool ggml_backend_buft_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend); + GGML_API bool ggml_backend_buft_is_host (ggml_backend_buffer_type_t buft); // buffer GGML_API void ggml_backend_buffer_free (ggml_backend_buffer_t buffer); @@ -29,6 +30,8 @@ extern "C" { GGML_API void ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer); GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); + GGML_API void ggml_backend_buffer_clear (ggml_backend_buffer_t buffer, uint8_t value); + GGML_API bool ggml_backend_buffer_is_host (ggml_backend_buffer_t buffer); GGML_API ggml_backend_buffer_type_t ggml_backend_buffer_type(ggml_backend_buffer_t buffer); // @@ -76,6 +79,10 @@ extern "C" { GGML_API ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void); +#ifdef GGML_USE_CPU_HBM + GGML_API ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void); +#endif + // // Backend registry // diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 32603a8d16a78d..f5e060d32ccbd0 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -9081,7 +9081,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) { char * buf; CUDA_CHECK(cudaMalloc(&buf, size)); - char * buf_host = (char*)data + offset_split; + char * buf_host = (char *)data + offset_split; // set padding to 0 to avoid possible NaN values if (size > original_size) { @@ -9226,11 +9226,10 @@ void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset) ggml_tensor_extra_gpu * extra = ggml_cuda_alloc_temp_tensor_extra(); - const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) || - tensor->op == GGML_OP_VIEW; + const bool inplace = tensor->view_src != nullptr; - if (inplace && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) { - ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra; + if (inplace && (tensor->view_src->backend == GGML_BACKEND_GPU || tensor->view_src->backend == GGML_BACKEND_GPU_SPLIT)) { + ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->view_src->extra; char * src0_ddc = (char *) src0_extra->data_device[g_main_device]; size_t view_offset = 0; if (tensor->op == GGML_OP_VIEW) { @@ -9317,7 +9316,7 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_ if (tensor->op == GGML_OP_MUL_MAT) { if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) { #ifndef NDEBUG - fprintf(stderr, "%s: cannot compute %s: src0->ne[3] = " PRId64 ", src1->ne[3] = " PRId64 " - fallback to CPU\n", __func__, tensor->name, tensor->src[0]->ne[3], tensor->src[1]->ne[3]); + fprintf(stderr, "%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, tensor->name, tensor->src[0]->ne[3], tensor->src[1]->ne[3]); #endif return false; } @@ -9523,7 +9522,7 @@ static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, g ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context; if (tensor->view_src != NULL && tensor->view_offs == 0) { - assert(tensor->view_src->buffer->buft == buffer->buft); // TODO + assert(tensor->view_src->buffer->buft == buffer->buft); tensor->backend = tensor->view_src->backend; tensor->extra = tensor->view_src->extra; return; @@ -9554,23 +9553,34 @@ static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, g } static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); - GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU); - CUDA_CHECK(cudaMemcpy((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice)); + ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context; - UNUSED(buffer); + ggml_cuda_set_device(ctx->device); + CUDA_CHECK(cudaDeviceSynchronize()); + + CUDA_CHECK(cudaMemcpy((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice)); } static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { - GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); - GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU); + ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context; + + ggml_cuda_set_device(ctx->device); + CUDA_CHECK(cudaDeviceSynchronize()); + CUDA_CHECK(cudaMemcpy(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost)); +} - UNUSED(buffer); +static void ggml_backend_cuda_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context; + + ggml_cuda_set_device(ctx->device); + CUDA_CHECK(cudaDeviceSynchronize()); + + CUDA_CHECK(cudaMemset(ctx->dev_ptr, value, buffer->size)); } static struct ggml_backend_buffer_i cuda_backend_buffer_interface = { @@ -9581,6 +9591,7 @@ static struct ggml_backend_buffer_i cuda_backend_buffer_interface = { /* .get_tensor = */ ggml_backend_cuda_buffer_get_tensor, /* .cpy_tensor_from = */ NULL, /* .cpy_tensor_to = */ NULL, + /* .clear = */ ggml_backend_cuda_buffer_clear, }; // cuda buffer type @@ -9632,35 +9643,36 @@ static bool ggml_backend_cuda_buffer_type_supports_backend(ggml_backend_buffer_t UNUSED(buft); } -static ggml_backend_buffer_type_i cuda_backend_buffer_type_interface = { +static ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface = { /* .alloc_buffer = */ ggml_backend_cuda_buffer_type_alloc_buffer, /* .get_alignment = */ ggml_backend_cuda_buffer_type_get_alignment, /* .get_alloc_size = */ ggml_backend_cuda_buffer_type_get_alloc_size, /* .supports_backend = */ ggml_backend_cuda_buffer_type_supports_backend, + /* .is_host = */ nullptr, }; ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) { - static struct ggml_backend_buffer_type ggml_backend_buffer_type_cuda[GGML_CUDA_MAX_DEVICES]; - static bool ggml_backend_buffer_type_cuda_initialized = false; - if (!ggml_backend_buffer_type_cuda_initialized) { + static struct ggml_backend_buffer_type ggml_backend_cuda_buffer_types[GGML_CUDA_MAX_DEVICES]; + + static bool ggml_backend_cuda_buffer_type_initialized = false; + + if (!ggml_backend_cuda_buffer_type_initialized) { for (int i = 0; i < GGML_CUDA_MAX_DEVICES; i++) { - ggml_backend_buffer_type_cuda[i] = { - /* .iface = */ cuda_backend_buffer_type_interface, + ggml_backend_cuda_buffer_types[i] = { + /* .iface = */ ggml_backend_cuda_buffer_type_interface, /* .context = */ (ggml_backend_buffer_type_context_t) (intptr_t) i, }; } - ggml_backend_buffer_type_cuda_initialized = true; + ggml_backend_cuda_buffer_type_initialized = true; } - return &ggml_backend_buffer_type_cuda[device]; + return &ggml_backend_cuda_buffer_types[device]; } // host buffer type static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { - ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context; - CUDA_CHECK(cudaFreeHost(ctx->dev_ptr)); - delete ctx; + CUDA_CHECK(cudaFreeHost(buffer->context)); } static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { @@ -9673,24 +9685,21 @@ static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggm buffer->iface.free_buffer = ggml_backend_cuda_host_buffer_free_buffer; return buffer; - - UNUSED(buft); } -struct ggml_backend_buffer_type_i cuda_backend_host_buffer_type_interface = { - /* .alloc_buffer = */ ggml_backend_cuda_host_buffer_type_alloc_buffer, - /* .get_alignment = */ ggml_backend_cpu_buffer_type()->iface.get_alignment, - /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size, - /* .supports_backend = */ ggml_backend_cpu_buffer_type()->iface.supports_backend, -}; - ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type() { - static struct ggml_backend_buffer_type ggml_backend_buffer_type_cuda_host = { - /* .iface = */ cuda_backend_host_buffer_type_interface, + static struct ggml_backend_buffer_type ggml_backend_cuda_buffer_type_host = { + /* .iface = */ { + /* .alloc_buffer = */ ggml_backend_cuda_host_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_cpu_buffer_type()->iface.get_alignment, + /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size, + /* .supports_backend = */ ggml_backend_cpu_buffer_type()->iface.supports_backend, + /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host, + }, /* .context = */ nullptr, }; - return &ggml_backend_buffer_type_cuda_host; + return &ggml_backend_cuda_buffer_type_host; } // backend @@ -9722,8 +9731,6 @@ static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_tens ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context; GGML_ASSERT(tensor->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type"); - GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); - GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU); CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, g_cudaStreams[cuda_ctx->device][0])); @@ -9733,8 +9740,6 @@ static void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend, const ggm ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context; GGML_ASSERT(tensor->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type"); - GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); - GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU); CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, g_cudaStreams[cuda_ctx->device][0])); diff --git a/ggml-metal.h b/ggml-metal.h index bf52d9cd34da48..b5e02b668a0f70 100644 --- a/ggml-metal.h +++ b/ggml-metal.h @@ -98,7 +98,10 @@ GGML_API ggml_backend_t ggml_backend_metal_init(void); GGML_API bool ggml_backend_is_metal(ggml_backend_t backend); +GGML_API ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size); + GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb); + GGML_API ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void); // helper to check if the device supports a specific family diff --git a/ggml-metal.m b/ggml-metal.m index 465679a6bb0c87..e60b93b36a7dec 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -180,7 +180,15 @@ @interface GGMLMetalClass : NSObject @implementation GGMLMetalClass @end -ggml_log_callback ggml_metal_log_callback = NULL; + +static void ggml_metal_default_log_callback(enum ggml_log_level level, const char * msg, void * user_data) { + fprintf(stderr, "%s", msg); + + UNUSED(level); + UNUSED(user_data); +} + +ggml_log_callback ggml_metal_log_callback = ggml_metal_default_log_callback; void * ggml_metal_log_user_data = NULL; void ggml_metal_log_set_callback(ggml_log_callback log_callback, void * user_data) { @@ -607,12 +615,24 @@ int ggml_metal_if_optimized(struct ggml_metal_context * ctx) { } // temporarily defined here for compatibility between ggml-backend and the old API -struct ggml_backend_metal_buffer_context { - void * data; + +struct ggml_backend_metal_buffer { + void * data; + size_t size; id metal; }; +struct ggml_backend_metal_buffer_context { + void * all_data; + size_t all_size; + bool owned; + + // multiple buffers are used only to avoid the maximum buffer size limitation when using mmap + int n_buffers; + struct ggml_backend_metal_buffer buffers[GGML_METAL_MAX_BUFFERS]; +}; + // finds the Metal buffer that contains the tensor data on the GPU device // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the // Metal buffer based on the host memory pointer @@ -622,17 +642,29 @@ int ggml_metal_if_optimized(struct ggml_metal_context * ctx) { const int64_t tsize = ggml_nbytes(t); + ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer; + // compatibility with ggml-backend - if (t->buffer && t->buffer->buft == ggml_backend_metal_buffer_type()) { - struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *) t->buffer->context; + if (buffer && buffer->buft == ggml_backend_metal_buffer_type()) { + struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *) buffer->context; + + // find the view that contains the tensor fully + for (int i = 0; i < buf_ctx->n_buffers; ++i) { + const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->buffers[i].data; - const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->data; + //GGML_METAL_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, buf_ctx->buffers[%d].size = %10ld\n", ioffs, tsize, ioffs + tsize, i, buf_ctx->buffers[i].size); + if (ioffs >= 0 && ioffs + tsize <= (int64_t) buf_ctx->buffers[i].size) { + *offs = (size_t) ioffs; - GGML_ASSERT(ioffs >= 0 && ioffs + tsize <= (int64_t) t->buffer->size); + //GGML_METAL_LOG_INFO("%s: tensor '%16s', offs = %8ld\n", __func__, t->name, *offs); + + return buf_ctx->buffers[i].metal; + } + } - *offs = (size_t) ioffs; + GGML_METAL_LOG_ERROR("%s: error: tensor '%s' buffer is nil\n", __func__, t->name); - return buf_ctx->metal; + return nil; } // find the view that contains the tensor fully @@ -2361,6 +2393,7 @@ void ggml_metal_graph_compute( // backend interface +// default buffer static id g_backend_device = nil; static int g_backend_device_ref_count = 0; @@ -2388,34 +2421,31 @@ static void ggml_backend_metal_free_device(void) { static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) { struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context; - return ctx->data; + return ctx->all_data; } static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) { struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context; - [ctx->metal release]; + for (int i = 0; i < ctx->n_buffers; i++) { + [ctx->buffers[i].metal release]; + } ggml_backend_metal_free_device(); - free(ctx->data); - free(ctx); + if (ctx->owned) { + free(ctx->all_data); + } - UNUSED(buffer); + free(ctx); } static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); - GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); - memcpy((char *)tensor->data + offset, data, size); UNUSED(buffer); } static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { - GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); - GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); - memcpy(data, (const char *)tensor->data + offset, size); UNUSED(buffer); @@ -2433,7 +2463,13 @@ static void ggml_backend_metal_buffer_cpy_tensor_to(ggml_backend_buffer_t buffer UNUSED(buffer); } -static struct ggml_backend_buffer_i metal_backend_buffer_i = { +static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context; + + memset(ctx->all_data, value, ctx->all_size); +} + +static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = { /* .free_buffer = */ ggml_backend_metal_buffer_free_buffer, /* .get_base = */ ggml_backend_metal_buffer_get_base, /* .init_tensor = */ NULL, @@ -2441,8 +2477,11 @@ static void ggml_backend_metal_buffer_cpy_tensor_to(ggml_backend_buffer_t buffer /* .get_tensor = */ ggml_backend_metal_buffer_get_tensor, /* .cpy_tensor_from = */ ggml_backend_metal_buffer_cpy_tensor_from, /* .cpy_tensor_to = */ ggml_backend_metal_buffer_cpy_tensor_to, + /* .clear = */ ggml_backend_metal_buffer_clear, }; +// default buffer type + static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { struct ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct ggml_backend_metal_buffer_context)); @@ -2453,13 +2492,46 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba size_aligned += (size_page - (size_aligned % size_page)); } - ctx->data = ggml_metal_host_malloc(size); - ctx->metal = [ggml_backend_metal_get_device() newBufferWithBytesNoCopy:ctx->data + id device = ggml_backend_metal_get_device(); + + ctx->all_data = ggml_metal_host_malloc(size_aligned); + ctx->all_size = size_aligned; + ctx->owned = true; + ctx->n_buffers = 1; + + ctx->buffers[0].data = ctx->all_data; + ctx->buffers[0].size = size; + ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil]; - return ggml_backend_buffer_init(buft, metal_backend_buffer_i, ctx, size); + if (ctx->buffers[0].metal == nil) { + GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0); + free(ctx); + ggml_backend_metal_free_device(); + return NULL; + } + + GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0); + + +#if TARGET_OS_OSX + GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)", + device.currentAllocatedSize / 1024.0 / 1024.0, + device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); + + if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) { + GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__); + } else { + GGML_METAL_LOG_INFO("\n"); + } +#else + GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0); +#endif + + + return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size); } static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { @@ -2470,7 +2542,13 @@ static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_t static bool ggml_backend_metal_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) { return ggml_backend_is_metal(backend) || ggml_backend_is_cpu(backend); - GGML_UNUSED(buft); + UNUSED(buft); +} + +static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) { + return true; + + UNUSED(buft); } ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) { @@ -2480,6 +2558,7 @@ ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) { /* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment, /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes /* .supports_backend = */ ggml_backend_metal_buffer_type_supports_backend, + /* .is_host = */ ggml_backend_metal_buffer_type_is_host, }, /* .context = */ NULL, }; @@ -2487,6 +2566,87 @@ ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) { return &ggml_backend_buffer_type_metal; } +// buffer from ptr + +ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) { + struct ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct ggml_backend_metal_buffer_context)); + + ctx->all_data = data; + ctx->all_size = size; + ctx->owned = false; + ctx->n_buffers = 0; + + const size_t size_page = sysconf(_SC_PAGESIZE); + size_t size_aligned = size; + if ((size_aligned % size_page) != 0) { + size_aligned += (size_page - (size_aligned % size_page)); + } + + id device = ggml_backend_metal_get_device(); + + // the buffer fits into the max buffer size allowed by the device + if (size_aligned <= device.maxBufferLength) { + ctx->buffers[ctx->n_buffers].data = data; + ctx->buffers[ctx->n_buffers].size = size; + + ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil]; + + if (ctx->buffers[ctx->n_buffers].metal == nil) { + GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0); + return false; + } + + GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0); + + ++ctx->n_buffers; + } else { + // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into + // one of the views + const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case + const size_t size_step = device.maxBufferLength - size_ovlp; + const size_t size_view = device.maxBufferLength; + + for (size_t i = 0; i < size; i += size_step) { + const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i); + + ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i); + ctx->buffers[ctx->n_buffers].size = size_step_aligned; + + ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil]; + + if (ctx->buffers[ctx->n_buffers].metal == nil) { + GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0); + return false; + } + + GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, offs = %12ld", __func__, size_step_aligned / 1024.0 / 1024.0, i); + if (i + size_step < size) { + GGML_METAL_LOG_INFO("\n"); + } + + ++ctx->n_buffers; + } + } + +#if TARGET_OS_OSX + GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)", + device.currentAllocatedSize / 1024.0 / 1024.0, + device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); + + if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) { + GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__); + } else { + GGML_METAL_LOG_INFO("\n"); + } +#else + GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0); +#endif + + return ggml_backend_buffer_init(ggml_backend_metal_buffer_type(), ggml_backend_metal_buffer_i, ctx, size); +} + +// backend + static const char * ggml_backend_metal_name(ggml_backend_t backend) { return "Metal"; @@ -2499,10 +2659,6 @@ static void ggml_backend_metal_free(ggml_backend_t backend) { free(backend); } -static void ggml_backend_metal_synchronize(ggml_backend_t backend) { - UNUSED(backend); -} - static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffer_type(ggml_backend_t backend) { return ggml_backend_metal_buffer_type(); @@ -2529,25 +2685,15 @@ static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct /* .get_tensor_async = */ NULL, /* .cpy_tensor_from_async = */ NULL, /* .cpy_tensor_to_async = */ NULL, - /* .synchronize = */ ggml_backend_metal_synchronize, - /* .graph_plan_create = */ NULL, // the metal implementation does not require creating graph plans atm + /* .synchronize = */ NULL, + /* .graph_plan_create = */ NULL, /* .graph_plan_free = */ NULL, /* .graph_plan_compute = */ NULL, /* .graph_compute = */ ggml_backend_metal_graph_compute, /* .supports_op = */ ggml_backend_metal_supports_op, }; -// TODO: make a common log callback for all backends in ggml-backend -static void ggml_backend_log_callback(enum ggml_log_level level, const char * msg, void * user_data) { - fprintf(stderr, "%s", msg); - - UNUSED(level); - UNUSED(user_data); -} - ggml_backend_t ggml_backend_metal_init(void) { - ggml_metal_log_set_callback(ggml_backend_log_callback, NULL); - struct ggml_metal_context * ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS); if (ctx == NULL) { diff --git a/ggml.c b/ggml.c index 6da65bd9286307..2361485148a15e 100644 --- a/ggml.c +++ b/ggml.c @@ -2383,20 +2383,8 @@ size_t ggml_get_mem_size(const struct ggml_context * ctx) { size_t ggml_get_max_tensor_size(const struct ggml_context * ctx) { size_t max_size = 0; - struct ggml_object * obj = ctx->objects_begin; - - while (obj != NULL) { - if (obj->type == GGML_OBJECT_TENSOR) { - struct ggml_tensor * tensor = (struct ggml_tensor *) ((char *) ctx->mem_buffer + obj->offs); - - const size_t size = ggml_nbytes(tensor); - - if (max_size < size) { - max_size = size; - } - } - - obj = obj->next; + for (struct ggml_tensor * tensor = ggml_get_first_tensor(ctx); tensor != NULL; tensor = ggml_get_next_tensor(ctx, tensor)) { + max_size = MAX(max_size, ggml_nbytes(tensor)); } return max_size; @@ -3093,7 +3081,7 @@ struct ggml_tensor * ggml_view_tensor( return result; } -struct ggml_tensor * ggml_get_first_tensor(struct ggml_context * ctx) { +struct ggml_tensor * ggml_get_first_tensor(const struct ggml_context * ctx) { struct ggml_object * obj = ctx->objects_begin; char * const mem_buffer = ctx->mem_buffer; @@ -3109,7 +3097,7 @@ struct ggml_tensor * ggml_get_first_tensor(struct ggml_context * ctx) { return NULL; } -struct ggml_tensor * ggml_get_next_tensor(struct ggml_context * ctx, struct ggml_tensor * tensor) { +struct ggml_tensor * ggml_get_next_tensor(const struct ggml_context * ctx, struct ggml_tensor * tensor) { struct ggml_object * obj = (struct ggml_object *) ((char *)tensor - GGML_OBJECT_SIZE); obj = obj->next; @@ -19213,6 +19201,10 @@ char * gguf_get_tensor_name(const struct gguf_context * ctx, int i) { return ctx->infos[i].name.data; } +enum ggml_type gguf_get_tensor_type(const struct gguf_context * ctx, int i) { + return ctx->infos[i].type; +} + // returns the index static int gguf_get_or_add_key(struct gguf_context * ctx, const char * key) { const int idx = gguf_find_key(ctx, key); diff --git a/ggml.h b/ggml.h index beacdc8be015fc..b17314897b35c6 100644 --- a/ggml.h +++ b/ggml.h @@ -735,8 +735,8 @@ extern "C" { GGML_API struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, struct ggml_tensor * src); // Context tensor enumeration and lookup - GGML_API struct ggml_tensor * ggml_get_first_tensor(struct ggml_context * ctx); - GGML_API struct ggml_tensor * ggml_get_next_tensor (struct ggml_context * ctx, struct ggml_tensor * tensor); + GGML_API struct ggml_tensor * ggml_get_first_tensor(const struct ggml_context * ctx); + GGML_API struct ggml_tensor * ggml_get_next_tensor (const struct ggml_context * ctx, struct ggml_tensor * tensor); GGML_API struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name); GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor); @@ -2135,10 +2135,11 @@ extern "C" { GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id); GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i); - GGML_API int gguf_get_n_tensors (const struct gguf_context * ctx); - GGML_API int gguf_find_tensor (const struct gguf_context * ctx, const char * name); - GGML_API size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int i); - GGML_API char * gguf_get_tensor_name (const struct gguf_context * ctx, int i); + GGML_API int gguf_get_n_tensors (const struct gguf_context * ctx); + GGML_API int gguf_find_tensor (const struct gguf_context * ctx, const char * name); + GGML_API size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int i); + GGML_API char * gguf_get_tensor_name (const struct gguf_context * ctx, int i); + GGML_API enum ggml_type gguf_get_tensor_type (const struct gguf_context * ctx, int i); // overrides existing values or adds a new one GGML_API void gguf_set_val_u8 (struct gguf_context * ctx, const char * key, uint8_t val); diff --git a/llama.cpp b/llama.cpp index 63ebe581bfae6b..ba970ce8d18091 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1,11 +1,12 @@ #define LLAMA_API_INTERNAL +//#define LLAMA_GGML_BACKEND_CUDA_TEST // for testing only - enables ggml-cuda through ggml-backend, disables partial offloading #include "llama.h" #include "unicode.h" #include "ggml.h" - #include "ggml-alloc.h" +#include "ggml-backend.h" #ifdef GGML_USE_CUBLAS # include "ggml-cuda.h" @@ -32,6 +33,7 @@ #include #if defined(_POSIX_MAPPED_FILES) #include + #include #endif #if defined(_POSIX_MEMLOCK_RANGE) #include @@ -712,38 +714,6 @@ static void ggml_graph_compute_helper(std::vector & buf, ggml_cgraph * // llama helpers // -inline void * llama_host_malloc(size_t n) { -#ifdef GGML_USE_CUBLAS - if (ggml_cublas_loaded()) { - return ggml_cuda_host_malloc(n); - } else { - return malloc(n); - } -#elif GGML_USE_METAL - return ggml_metal_host_malloc(n); -#elif GGML_USE_CPU_HBM - return hbw_malloc(n); -#else - return malloc(n); -#endif -} - -inline void llama_host_free(void * ptr) { -#ifdef GGML_USE_CUBLAS - if (ggml_cublas_loaded()) { - return ggml_cuda_host_free(ptr); - } else { - return free(ptr); - } -#elif GGML_USE_METAL - return ggml_metal_host_free(ptr); -#elif GGML_USE_CPU_HBM - return hbw_free(ptr); -#else - return free(ptr); -#endif -} - #if defined(_WIN32) static std::string llama_format_win_err(DWORD err) { LPSTR buf; @@ -758,40 +728,10 @@ static std::string llama_format_win_err(DWORD err) { } #endif -struct llama_buffer { - void * data = NULL; - size_t size = 0; - - // fallback to malloc / free - // useful in cases where CUDA can try to allocate PINNED memory - bool fallback = false; - - void resize(size_t n) { - llama_host_free(data); - - data = llama_host_malloc(n); - if (!data) { - fallback = true; - data = malloc(n); - } else { - fallback = false; - } - - GGML_ASSERT(data); - size = n; - } - - ~llama_buffer() { - if (data) { - if (fallback) { // NOLINT - free(data); - } else { - llama_host_free(data); - } - } - - data = NULL; - } +template +struct no_init { + T value; + no_init() { /* do nothing */ } }; struct llama_file { @@ -879,6 +819,9 @@ struct llama_mmap { #ifdef _POSIX_MAPPED_FILES static constexpr bool SUPPORTED = true; + // list of mapped fragments (first_offset, last_offset) + std::vector> mapped_fragments; + llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1 /* -1 = max value */, bool numa = false) { size = file->size; int fd = fileno(file->fp); @@ -886,17 +829,22 @@ struct llama_mmap { // prefetch/readahead impairs performance on NUMA systems if (numa) { prefetch = 0; } #ifdef __linux__ + // advise the kernel to read the file sequentially (increases readahead) + if (posix_fadvise(fd, 0, 0, POSIX_FADV_SEQUENTIAL)) { + LLAMA_LOG_WARN("warning: posix_fadvise(.., POSIX_FADV_SEQUENTIAL) failed: %s\n", + strerror(errno)); + } if (prefetch) { flags |= MAP_POPULATE; } #endif addr = mmap(NULL, file->size, PROT_READ, flags, fd, 0); - if (addr == MAP_FAILED) { + if (addr == MAP_FAILED) { // NOLINT throw std::runtime_error(format("mmap failed: %s", strerror(errno))); } if (prefetch > 0) { - // Advise the kernel to preload the mapped memory + // advise the kernel to preload the mapped memory if (posix_madvise(addr, std::min(file->size, prefetch), POSIX_MADV_WILLNEED)) { - fprintf(stderr, "warning: posix_madvise(.., POSIX_MADV_WILLNEED) failed: %s\n", + LLAMA_LOG_WARN("warning: posix_madvise(.., POSIX_MADV_WILLNEED) failed: %s\n", strerror(errno)); } } @@ -904,14 +852,81 @@ struct llama_mmap { // advise the kernel not to use readahead // (because the next page might not belong on the same node) if (posix_madvise(addr, file->size, POSIX_MADV_RANDOM)) { - fprintf(stderr, "warning: posix_madvise(.., POSIX_MADV_RANDOM) failed: %s\n", + LLAMA_LOG_WARN("warning: posix_madvise(.., POSIX_MADV_RANDOM) failed: %s\n", strerror(errno)); } } + + // initialize list of mapped_fragments + mapped_fragments.emplace_back(0, file->size); + } + + static void align_range(size_t * first, size_t * last, size_t page_size) { + // align first to the next page + size_t offset_in_page = *first & (page_size - 1); + size_t offset_to_page = offset_in_page == 0 ? 0 : page_size - offset_in_page; + *first += offset_to_page; + + // align last to the previous page + *last = *last & ~(page_size - 1); + + if (*last <= *first) { + *last = *first; + } + } + + // partially unmap the file in the range [first, last) + void unmap_fragment(size_t first, size_t last) { + // note: this function must not be called multiple times with overlapping ranges + // otherwise, there is a risk of invalidating addresses that have been repurposed for other mappings + int page_size = sysconf(_SC_PAGESIZE); + align_range(&first, &last, page_size); + size_t len = last - first; + + if (len == 0) { + return; + } + + GGML_ASSERT(first % page_size == 0); + GGML_ASSERT(last % page_size == 0); + GGML_ASSERT(last > first); + + void * next_page_start = (uint8_t *) addr + first; + + // unmap the range + if (munmap(next_page_start, len)) { + LLAMA_LOG_WARN("warning: munmap failed: %s\n", strerror(errno)); + } + + // update the list of mapped fragments to avoid unmapping the same range again in the destructor + std::vector> new_mapped_fragments; + for (const auto & frag : mapped_fragments) { + if (frag.first < first && frag.second > last) { + // the range is in the middle of the fragment, split it + new_mapped_fragments.emplace_back(frag.first, first); + new_mapped_fragments.emplace_back(last, frag.second); + } else if (frag.first < first && frag.second > first) { + // the range starts in the middle of the fragment + new_mapped_fragments.emplace_back(frag.first, first); + } else if (frag.first < last && frag.second > last) { + // the range ends in the middle of the fragment + new_mapped_fragments.emplace_back(last, frag.second); + } else if (frag.first >= first && frag.second <= last) { + // the range covers the entire fragment + } else { + // the range is outside the fragment + new_mapped_fragments.push_back(frag); + } + } + mapped_fragments = std::move(new_mapped_fragments); } ~llama_mmap() { - munmap(addr, size); + for (const auto & frag : mapped_fragments) { + if (munmap((char *) addr + frag.first, frag.second - frag.first)) { + LLAMA_LOG_WARN("warning: munmap failed: %s\n", strerror(errno)); + } + } } #elif defined(_WIN32) static constexpr bool SUPPORTED = true; @@ -959,6 +974,12 @@ struct llama_mmap { } } + void unmap_fragment(size_t first, size_t last) { + // not supported + GGML_UNUSED(first); + GGML_UNUSED(last); + } + ~llama_mmap() { if (!UnmapViewOfFile(addr)) { fprintf(stderr, "warning: UnmapViewOfFile failed: %s\n", @@ -975,6 +996,13 @@ struct llama_mmap { throw std::runtime_error(std::string("mmap not supported")); } + + void unmap(size_t offset, size_t len) { + (void) offset; + (void) len; + + throw std::runtime_error(std::string("mmap not supported")); + } #endif }; @@ -1148,6 +1176,26 @@ static std::string llama_token_to_piece(const struct llama_context * ctx, llama_ return std::string(result.data(), result.size()); } +static ggml_backend_buffer_type_t llama_default_buffer_type(int n_gpu_layers) { +#ifdef GGML_USE_METAL + if (n_gpu_layers > 0) { + return ggml_backend_metal_buffer_type(); + } +#elif defined(GGML_USE_CUBLAS) && defined(LLAMA_GGML_BACKEND_CUDA_TEST) + if (n_gpu_layers > 0) { + return ggml_backend_cuda_buffer_type(0); + } +#elif defined(GGML_USE_CUBLAS) + return ggml_backend_cuda_host_buffer_type(); +#elif defined(GGML_USE_CPU_HBM) + return ggml_backend_cpu_hbm_buffer_type(); +#endif + + return ggml_backend_cpu_buffer_type(); + + GGML_UNUSED(n_gpu_layers); +} + // // globals // @@ -1348,14 +1396,10 @@ struct llama_kv_cache { struct ggml_context * ctx = NULL; - llama_buffer buf; + ggml_backend_buffer_t buf = NULL; ~llama_kv_cache() { - if (ctx) { - ggml_free(ctx); - } - -#ifdef GGML_USE_CUBLAS +#if defined(GGML_USE_CUBLAS) && !defined(LLAMA_GGML_BACKEND_CUDA_TEST) if (ggml_cublas_loaded()) { for (size_t i = 0; i < k_l.size(); ++i) { ggml_cuda_free_data(k_l[i]); @@ -1363,6 +1407,11 @@ struct llama_kv_cache { } } #endif + if (ctx) { + ggml_free(ctx); + } + + ggml_backend_buffer_free(buf); } }; @@ -1402,11 +1451,11 @@ struct llama_vocab { id special_suffix_id = 32008; id special_eot_id = 32010; - int find_bpe_rank(std::string token_left, std::string token_right) const { - GGML_ASSERT(token_left.find(" ") == std::string::npos); - GGML_ASSERT(token_left.find("\n") == std::string::npos); - GGML_ASSERT(token_right.find(" ") == std::string::npos); - GGML_ASSERT(token_right.find("\n") == std::string::npos); + int find_bpe_rank(const std::string & token_left, const std::string & token_right) const { + GGML_ASSERT(token_left.find(' ') == std::string::npos); + GGML_ASSERT(token_left.find('\n') == std::string::npos); + GGML_ASSERT(token_right.find(' ') == std::string::npos); + GGML_ASSERT(token_right.find('\n') == std::string::npos); auto it = bpe_ranks.find(std::make_pair(token_left, token_right)); if (it == bpe_ranks.end()) { @@ -1448,7 +1497,7 @@ struct llama_model { struct ggml_context * ctx = NULL; // the model memory buffer - llama_buffer buf; + ggml_backend_buffer_t buf = NULL; // model memory mapped file std::unique_ptr mapping; @@ -1464,11 +1513,7 @@ struct llama_model { int64_t t_start_us = 0; ~llama_model() { - if (ctx) { - ggml_free(ctx); - } - -#ifdef GGML_USE_CUBLAS +#if defined(GGML_USE_CUBLAS) && !defined(LLAMA_GGML_BACKEND_CUDA_TEST) if (ggml_cublas_loaded()) { for (size_t i = 0; i < tensors_by_name.size(); ++i) { ggml_cuda_free_data(tensors_by_name[i].second); @@ -1482,24 +1527,26 @@ struct llama_model { ggml_cl_free_data(tensors_by_name[i].second); } #endif + if (ctx) { + ggml_free(ctx); + } + + ggml_backend_buffer_free(buf); } }; struct llama_context { llama_context(const llama_model & model) : model(model), t_start_us(model.t_start_us), t_load_us(model.t_load_us) {} ~llama_context() { -#ifdef GGML_USE_METAL - if (ctx_metal) { - ggml_metal_free(ctx_metal); - } -#endif - if (alloc) { - ggml_allocr_free(alloc); - } + ggml_allocr_free(alloc); + ggml_backend_buffer_free(buf_alloc); + ggml_backend_free(backend); } llama_cparams cparams; + ggml_backend_t backend = nullptr; + const llama_model & model; // key + value cache for the self attention @@ -1530,18 +1577,13 @@ struct llama_context { // input embedding (1-dimensional array: [n_embd]) std::vector embedding; - // reusable buffer for `struct ggml_graph_plan.work_data` - std::vector work_buffer; - // memory buffers used to evaluate the model - llama_buffer buf_compute; - - llama_buffer buf_alloc; + std::vector buf_compute_meta; + ggml_backend_buffer_t buf_alloc = NULL; ggml_allocr * alloc = NULL; -#ifdef GGML_USE_METAL - ggml_metal_context * ctx_metal = NULL; -#endif + // temporary buffer for copying data to/from the backend + std::vector> buf_copy; #ifdef GGML_USE_MPI ggml_mpi_context * ctx_mpi = NULL; @@ -1563,9 +1605,6 @@ static bool llama_kv_cache_init( const uint32_t n_embd = hparams.n_embd_gqa(); const uint32_t n_layer = hparams.n_layer; - const int64_t n_mem = n_layer*n_ctx; - const int64_t n_elements = n_embd*n_mem; - cache.has_shift = false; cache.head = 0; @@ -1575,13 +1614,10 @@ static bool llama_kv_cache_init( cache.cells.clear(); cache.cells.resize(n_ctx); - cache.buf.resize(ggml_row_size(ktype, n_elements) + ggml_row_size(vtype, n_elements) + 2u*n_layer*ggml_tensor_overhead()); - memset(cache.buf.data, 0, cache.buf.size); - struct ggml_init_params params; - params.mem_size = cache.buf.size; - params.mem_buffer = cache.buf.data; - params.no_alloc = false; + params.mem_size = 2u*n_layer*ggml_tensor_overhead(); + params.mem_buffer = NULL; + params.no_alloc = true; cache.ctx = ggml_init(params); @@ -1595,9 +1631,7 @@ static bool llama_kv_cache_init( cache.k_l.reserve(n_layer); cache.v_l.reserve(n_layer); - const int i_gpu_start = (int) n_layer - n_gpu_layers; GGML_UNUSED(i_gpu_start); - - GGML_UNUSED(offload); + const int i_gpu_start = (int) n_layer - n_gpu_layers; for (int i = 0; i < (int) n_layer; i++) { ggml_tensor * k = ggml_new_tensor_1d(cache.ctx, ktype, n_embd*n_ctx); @@ -1606,23 +1640,35 @@ static bool llama_kv_cache_init( ggml_format_name(v, "cache_v_l%d", i); cache.k_l.push_back(k); cache.v_l.push_back(v); -#ifdef GGML_USE_CUBLAS +#if defined(GGML_USE_CUBLAS) && !defined(LLAMA_GGML_BACKEND_CUDA_TEST) if (i >= i_gpu_start) { if (offload) { ggml_cuda_assign_buffers_no_scratch(k); - vram_kv_cache += ggml_nbytes(k); ggml_cuda_assign_buffers_no_scratch(v); + vram_kv_cache += ggml_nbytes(k); vram_kv_cache += ggml_nbytes(v); + // HACK: mark tensor as allocated + k->data = v->data = (void *)(uintptr_t)1; } } #endif // GGML_USE_CUBLAS } + // allocate tensors + cache.buf = ggml_backend_alloc_ctx_tensors_from_buft(cache.ctx, llama_default_buffer_type(n_gpu_layers)); + + // buf may be NULL with full offload + if (cache.buf) { + // initialize the buffer to avoid NaNs in the padding + ggml_backend_buffer_clear(cache.buf, 0); + } + if (vram_kv_cache > 0) { LLAMA_LOG_INFO("%s: VRAM kv self = %.2f MB\n", __func__, vram_kv_cache / 1024.0 / 1024.0); } - GGML_UNUSED(n_gpu_layers); + GGML_UNUSED(i_gpu_start); + GGML_UNUSED(offload); return true; } @@ -2073,14 +2119,13 @@ struct llama_model_loader { enum ggml_type type_max = GGML_TYPE_F32; for (int i = 0; i < n_tensors; i++) { - const char * name = gguf_get_tensor_name(ctx_gguf, i); - struct ggml_tensor * meta = ggml_get_tensor(ctx_meta, name); + enum ggml_type type = gguf_get_tensor_type(ctx_gguf, i); - n_type[meta->type]++; + n_type[type]++; - if (n_type_max < n_type[meta->type]) { - n_type_max = n_type[meta->type]; - type_max = meta->type; + if (n_type_max < n_type[type]) { + n_type_max = n_type[type]; + type_max = type; } // LLAMA_LOG_INFO("%s: - tensor %4d: %32s %-8s [ %s ]\n", __func__, i, name, ggml_type_name(meta->type), llama_format_tensor_shape(meta).c_str()); @@ -2221,34 +2266,19 @@ struct llama_model_loader { return gguf_get_tensor_name(ctx_gguf, i); } - struct ggml_tensor * get_tensor_meta(int i) const { - return ggml_get_tensor(ctx_meta, get_tensor_name(i)); + struct ggml_tensor * get_tensor_meta(const char * name) const { + return ggml_get_tensor(ctx_meta, name); } - void calc_sizes(size_t & ctx_size_p, size_t & mmapped_size_p) const { - ctx_size_p = 0; - mmapped_size_p = 0; - - for (int i = 0; i < n_tensors; i++) { - struct ggml_tensor * meta = get_tensor_meta(i); - ctx_size_p += sizeof(struct ggml_tensor) + GGML_OBJECT_SIZE; - (use_mmap ? mmapped_size_p : ctx_size_p) += ggml_nbytes_pad(meta); - } + struct ggml_tensor * get_tensor_meta(int i) const { + return get_tensor_meta(get_tensor_name(i)); } struct ggml_tensor * create_tensor_for(struct ggml_context * ctx, struct ggml_tensor * meta, ggml_backend_type backend) { - if (backend != GGML_BACKEND_CPU) { - ggml_set_no_alloc(ctx, true); - } - struct ggml_tensor * tensor = ggml_dup_tensor(ctx, meta); tensor->backend = backend; // TODO: ggml_set_backend ggml_set_name(tensor, ggml_get_name(meta)); - if (backend != GGML_BACKEND_CPU) { - ggml_set_no_alloc(ctx, use_mmap); - } - n_created++; return tensor; @@ -2306,90 +2336,137 @@ struct llama_model_loader { return gguf_get_data_offset(ctx_gguf) + gguf_get_tensor_offset(ctx_gguf, idx); } + void init_mapping(bool prefetch = true) { + /* + // prefetch only CPU tensors + if (use_mmap) { + size_t size_pref = 0; // prefetch + + for (int i = 0; i < gguf_get_n_tensors(ctx_gguf); i++) { + struct ggml_tensor * cur = ggml_get_tensor(ctx, gguf_get_tensor_name(ctx_gguf, i)); + if (cur->backend == GGML_BACKEND_CPU) { + size_t tensor_end = gguf_get_tensor_offset(ctx_gguf, i) + ggml_nbytes(cur); + size_pref = std::max(size_pref, tensor_end); + } + } + mapping.reset(new llama_mmap(&file, gguf_get_data_offset(ctx_gguf) + size_pref, ggml_is_numa())); + } + */ + // prefetch the whole file - all the data is needed anyway + if (use_mmap) { + mapping.reset(new llama_mmap(&file, prefetch ? -1 : 0, ggml_is_numa())); + } + } + + // for backwards compatibility, does not support ggml-backend void load_data_for(struct ggml_tensor * cur) const { const size_t offs = file_offset(ggml_get_name(cur)); - if (use_mmap) { - cur->data = (uint8_t *) mapping->addr + offs; + if (use_mmap && mapping) { + GGML_ASSERT(cur->data == nullptr); + cur->data = (uint8_t *)mapping->addr + offs; } else { + GGML_ASSERT(cur->data != nullptr); file.seek(offs, SEEK_SET); file.read_raw(cur->data, ggml_nbytes(cur)); } } - void load_all_data(struct ggml_context * ctx, llama_progress_callback progress_callback, void * progress_callback_user_data, llama_mlock * lmlock) { + void load_all_data(struct ggml_context * ctx, llama_progress_callback progress_callback, void * progress_callback_user_data, ggml_backend_buffer_t buf_mmap, llama_mlock * lmlock) const { size_t size_data = 0; - size_t size_lock = 0; - size_t size_pref = 0; // prefetch for (int i = 0; i < gguf_get_n_tensors(ctx_gguf); i++) { struct ggml_tensor * cur = ggml_get_tensor(ctx, gguf_get_tensor_name(ctx_gguf, i)); size_data += ggml_nbytes(cur); - if (cur->backend == GGML_BACKEND_CPU) { - size_pref += ggml_nbytes(cur); - } } - if (use_mmap) { - mapping.reset(new llama_mmap(&file, size_pref, ggml_is_numa())); + if (use_mmap && buf_mmap) { if (lmlock) { lmlock->init(mapping->addr); } } - size_t done_size = 0; +#if (defined(GGML_USE_CUBLAS) && !defined(LLAMA_GGML_BACKEND_CUDA_TEST)) || defined(GGML_USE_CLBLAST) + const bool legacy_offload = true; +#else + const bool legacy_offload = false; +#endif + + std::vector> read_buf; + + size_t size_done = 0; + + size_t mmap_first = -1; + size_t mmap_last = 0; + for (int i = 0; i < gguf_get_n_tensors(ctx_gguf); i++) { struct ggml_tensor * cur = ggml_get_tensor(ctx, gguf_get_tensor_name(ctx_gguf, i)); GGML_ASSERT(cur); // unused tensors should have been caught by load_data already if (progress_callback) { - progress_callback((float) done_size / size_data, progress_callback_user_data); - } - - // allocate temp buffer if not using mmap - if (!use_mmap && cur->data == NULL) { - GGML_ASSERT(cur->backend != GGML_BACKEND_CPU); - #ifdef GGML_USE_CPU_HBM - cur->data = (uint8_t*)hbw_malloc(ggml_nbytes(cur)); - #else - cur->data = (uint8_t*)malloc(ggml_nbytes(cur)); - #endif + progress_callback((float) size_done / size_data, progress_callback_user_data); } - load_data_for(cur); + const size_t offs = file_offset(ggml_get_name(cur)); - switch (cur->backend) { - case GGML_BACKEND_CPU: - if (use_mmap && lmlock) { - size_lock += ggml_nbytes(cur); - lmlock->grow_to(size_lock); + if (!legacy_offload || cur->backend == GGML_BACKEND_CPU) { + if (use_mmap && mapping) { + if (buf_mmap) { + ggml_backend_tensor_alloc(buf_mmap, cur, (uint8_t *) mapping->addr + offs); + if (lmlock) { + lmlock->grow_to(offs + ggml_nbytes(cur)); + } + mmap_first = std::min(mmap_first, offs); + mmap_last = std::max(mmap_last, offs + ggml_nbytes(cur)); + } else { + ggml_backend_tensor_set(cur, (uint8_t *) mapping->addr + offs, 0, ggml_nbytes(cur)); } - break; -#ifdef GGML_USE_CUBLAS - case GGML_BACKEND_GPU: - case GGML_BACKEND_GPU_SPLIT: - // old code: - //ggml_cuda_transform_tensor(lt.data, lt.ggml_tensor); - - // TODO: test if this works !! - ggml_cuda_transform_tensor(cur->data, cur); - if (!use_mmap) { - free(cur->data); + } else { + if (ggml_backend_buffer_is_host(cur->buffer)) { + file.seek(offs, SEEK_SET); + file.read_raw(cur->data, ggml_nbytes(cur)); + } else { + read_buf.resize(ggml_nbytes(cur)); + file.seek(offs, SEEK_SET); + file.read_raw(read_buf.data(), ggml_nbytes(cur)); + ggml_backend_tensor_set(cur, read_buf.data(), 0, ggml_nbytes(cur)); } - break; + } + } else { + // HACK: mark tensor as allocated + cur->data = (void *)(uintptr_t)1; + void * data; + if (use_mmap && mapping) { + data = (uint8_t *) mapping->addr + offs; + } else { + read_buf.resize(ggml_nbytes(cur)); + file.seek(offs, SEEK_SET); + file.read_raw(read_buf.data(), ggml_nbytes(cur)); + data = read_buf.data(); + } + +#if defined(GGML_USE_CUBLAS) && !defined(LLAMA_GGML_BACKEND_CUDA_TEST) + ggml_cuda_transform_tensor(data, cur); #elif defined(GGML_USE_CLBLAST) - case GGML_BACKEND_GPU: - ggml_cl_transform_tensor(cur->data, cur); - if (!use_mmap) { - free(cur->data); - } - break; + GGML_ASSERT(cur->backend == GGML_BACKEND_GPU); + ggml_cl_transform_tensor(data, cur); +#else + GGML_ASSERT(!"GPU tensor without a GPU backend"); + GGML_UNUSED(data); #endif - default: - continue; } - done_size += ggml_nbytes(cur); + size_done += ggml_nbytes(cur); + } + + // unmap offloaded tensors and metadata + if (use_mmap && mapping) { + mapping->unmap_fragment(0, mmap_first); + mapping->unmap_fragment(mmap_last, mapping->size); + } + + if (progress_callback) { + progress_callback(1.0f, progress_callback_user_data); } } }; @@ -2983,25 +3060,16 @@ static void llm_load_tensors( model.n_gpu_layers = n_gpu_layers; - size_t ctx_size; - size_t mmapped_size; - - ml.calc_sizes(ctx_size, mmapped_size); + size_t ctx_size = ggml_tensor_overhead() * ml.n_tensors; - LLAMA_LOG_INFO("%s: ggml ctx size = %7.2f MiB\n", __func__, ctx_size/1024.0/1024.0); + LLAMA_LOG_INFO("%s: ggml ctx size = %7.2f MiB\n", __func__, ctx_size/1024.0/1024.0); // create the ggml context { - model.buf.resize(ctx_size); - if (use_mlock) { - model.mlock_buf.init (model.buf.data); - model.mlock_buf.grow_to(model.buf.size); - } - struct ggml_init_params params = { - /*.mem_size =*/ model.buf.size, - /*.mem_buffer =*/ model.buf.data, - /*.no_alloc =*/ ml.use_mmap, + /*.mem_size =*/ ctx_size, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, }; model.ctx = ggml_init(params); @@ -3015,22 +3083,21 @@ static void llm_load_tensors( enum ggml_backend_type llama_backend_offload = GGML_BACKEND_CPU; enum ggml_backend_type llama_backend_offload_split = GGML_BACKEND_CPU; -#ifdef GGML_USE_CUBLAS +#if defined(GGML_USE_CUBLAS) && !defined(LLAMA_GGML_BACKEND_CUDA_TEST) if (ggml_cublas_loaded()) { LLAMA_LOG_INFO("%s: using " GGML_CUDA_NAME " for GPU acceleration\n", __func__); ggml_cuda_set_main_device(main_gpu); - llama_backend_offload = GGML_BACKEND_GPU; + llama_backend_offload = GGML_BACKEND_GPU; llama_backend_offload_split = GGML_BACKEND_GPU_SPLIT; } #elif defined(GGML_USE_CLBLAST) LLAMA_LOG_INFO("%s: using OpenCL for GPU acceleration\n", __func__); - llama_backend_offload = GGML_BACKEND_GPU; + llama_backend_offload = GGML_BACKEND_GPU; llama_backend_offload_split = GGML_BACKEND_GPU; #endif - // prepare memory for the weights - size_t vram_weights = 0; + // create tensors for the weights { const int64_t n_embd = hparams.n_embd; const int64_t n_embd_gqa = hparams.n_embd_gqa(); @@ -3059,13 +3126,6 @@ static void llm_load_tensors( model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); - - if (backend_norm == GGML_BACKEND_GPU) { - vram_weights += ggml_nbytes(model.output_norm); - } - if (backend_output == GGML_BACKEND_GPU_SPLIT) { - vram_weights += ggml_nbytes(model.output); - } } const uint32_t n_ff = hparams.n_ff; @@ -3115,28 +3175,6 @@ static void llm_load_tensors( layer.ffn_up_exp[x] = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP_EXP, "weight", i, x), {n_embd, n_ff}, backend_split); } } - - if (backend == GGML_BACKEND_GPU) { - vram_weights += - ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) + - ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + - (layer.bq ? ggml_nbytes(layer.bq) : 0) + - (layer.bk ? ggml_nbytes(layer.bk) : 0) + - (layer.bv ? ggml_nbytes(layer.bv) : 0) + - (layer.bo ? ggml_nbytes(layer.bo) : 0) + - ggml_nbytes(layer.ffn_norm); - - if (layer.ffn_gate_inp == nullptr) { - vram_weights += - ggml_nbytes(layer.ffn_gate) + ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_up); - } else { - vram_weights += ggml_nbytes(layer.ffn_gate_inp); - for (uint32_t x = 0; x < hparams.n_expert; ++x) { - vram_weights += - ggml_nbytes(layer.ffn_gate_exp[x]) + ggml_nbytes(layer.ffn_down_exp[x]) + ggml_nbytes(layer.ffn_up_exp[x]); - } - } - } } } break; case LLM_ARCH_BAICHUAN: @@ -3156,13 +3194,6 @@ static void llm_load_tensors( model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); - - if (backend_norm == GGML_BACKEND_GPU) { - vram_weights += ggml_nbytes(model.output_norm); - } - if (backend_output == GGML_BACKEND_GPU_SPLIT) { - vram_weights += ggml_nbytes(model.output); - } } const uint32_t n_ff = hparams.n_ff; @@ -3189,19 +3220,10 @@ static void llm_load_tensors( layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split); layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); - - if (backend == GGML_BACKEND_GPU) { - vram_weights += - ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) + - ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) + - ggml_nbytes(layer.ffn_gate) + ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_up); - } } } break; case LLM_ARCH_FALCON: { - // TODO: CPU-only for now - model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); // output @@ -3220,14 +3242,6 @@ static void llm_load_tensors( model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm); model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); - - if (backend_norm == GGML_BACKEND_GPU) { - vram_weights += ggml_nbytes(model.output_norm); - vram_weights += ggml_nbytes(model.output_norm_b); - } - if (backend_output == GGML_BACKEND_GPU_SPLIT) { - vram_weights += ggml_nbytes(model.output); - } } const uint32_t n_ff = hparams.n_ff; @@ -3248,11 +3262,6 @@ static void llm_load_tensors( if (gguf_find_tensor(ml.ctx_gguf, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i).c_str()) >= 0) { layer.attn_norm_2 = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, backend); layer.attn_norm_2_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, backend); - - if (backend == GGML_BACKEND_GPU) { - vram_weights += ggml_nbytes(layer.attn_norm_2); - vram_weights += ggml_nbytes(layer.attn_norm_2_b); - } } layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split); @@ -3260,13 +3269,6 @@ static void llm_load_tensors( layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); - - if (backend == GGML_BACKEND_GPU) { - vram_weights += - ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.attn_norm_b) + - ggml_nbytes(layer.wqkv) + ggml_nbytes(layer.wo) + - ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_up); - } } } break; case LLM_ARCH_STARCODER: @@ -3290,14 +3292,6 @@ static void llm_load_tensors( model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm); model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); - - if (backend_norm == GGML_BACKEND_GPU) { - vram_weights += ggml_nbytes(model.output_norm); - vram_weights += ggml_nbytes(model.output_norm_b); - } - if (backend_output == GGML_BACKEND_GPU_SPLIT) { - vram_weights += ggml_nbytes(model.output); - } } const uint32_t n_ff = hparams.n_ff; @@ -3329,16 +3323,6 @@ static void llm_load_tensors( layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); layer.ffn_up_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend); - - if (backend == GGML_BACKEND_GPU) { - vram_weights += - ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.attn_norm_b) + - ggml_nbytes(layer.wqkv) + ggml_nbytes(layer.bqkv) + - ggml_nbytes(layer.wo) + ggml_nbytes(layer.bo) + - ggml_nbytes(layer.ffn_norm) + ggml_nbytes(layer.ffn_norm_b) + - ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_down_b) + - ggml_nbytes(layer.ffn_up) + ggml_nbytes(layer.ffn_up_b); - } } } break; case LLM_ARCH_PERSIMMON: @@ -3360,14 +3344,6 @@ static void llm_load_tensors( model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm); model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); - - if (backend_norm == GGML_BACKEND_GPU) { - vram_weights += ggml_nbytes(model.output_norm); - vram_weights += ggml_nbytes(model.output_norm_b); - } - if (backend_output == GGML_BACKEND_GPU_SPLIT) { - vram_weights += ggml_nbytes(model.output); - } } const uint32_t n_ff = hparams.n_ff; @@ -3397,8 +3373,6 @@ static void llm_load_tensors( } break; case LLM_ARCH_BLOOM: { - // TODO: CPU-only for now - model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); model.tok_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, GGML_BACKEND_CPU); model.tok_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, GGML_BACKEND_CPU); @@ -3419,14 +3393,6 @@ static void llm_load_tensors( model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm); model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); - - if (backend_norm == GGML_BACKEND_GPU) { - vram_weights += ggml_nbytes(model.output_norm); - vram_weights += ggml_nbytes(model.output_norm_b); - } - if (backend_output == GGML_BACKEND_GPU_SPLIT) { - vram_weights += ggml_nbytes(model.output); - } } const uint32_t n_ff = hparams.n_ff; @@ -3458,16 +3424,6 @@ static void llm_load_tensors( layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); layer.ffn_up_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend); - - if (backend == GGML_BACKEND_GPU) { - vram_weights += - ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.attn_norm_b) + - ggml_nbytes(layer.wqkv) + ggml_nbytes(layer.bqkv) + - ggml_nbytes(layer.wo) + ggml_nbytes(layer.bo) + - ggml_nbytes(layer.ffn_norm) + ggml_nbytes(layer.ffn_norm_b) + - ggml_nbytes(layer.ffn_up) + ggml_nbytes(layer.ffn_up_b) + - ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_down_b); - } } } break; case LLM_ARCH_MPT: @@ -3489,13 +3445,6 @@ static void llm_load_tensors( model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); - - if (backend_norm == GGML_BACKEND_GPU) { - vram_weights += ggml_nbytes(model.output_norm); - } - if (backend_output == GGML_BACKEND_GPU_SPLIT) { - vram_weights += ggml_nbytes(model.output); - } } const uint32_t n_ff = hparams.n_ff; @@ -3518,16 +3467,6 @@ static void llm_load_tensors( layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); - - if (backend == GGML_BACKEND_GPU) { - vram_weights += - ggml_nbytes(layer.attn_norm) + - ggml_nbytes(layer.wqkv) + - ggml_nbytes(layer.wo) + - ggml_nbytes(layer.ffn_norm) + - ggml_nbytes(layer.ffn_down) + - ggml_nbytes(layer.ffn_up); - } } } break; case LLM_ARCH_STABLELM: @@ -3550,13 +3489,6 @@ static void llm_load_tensors( model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm); model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); - - if (backend_norm == GGML_BACKEND_GPU) { - vram_weights += ggml_nbytes(model.output_norm); - } - if (backend_output == GGML_BACKEND_GPU_SPLIT) { - vram_weights += ggml_nbytes(model.output); - } } const uint32_t n_ff = hparams.n_ff; @@ -3588,13 +3520,6 @@ static void llm_load_tensors( layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split); layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); - - if (backend == GGML_BACKEND_GPU) { - vram_weights += - ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) + - ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) + - ggml_nbytes(layer.ffn_gate) + ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_up); - } } } break; case LLM_ARCH_QWEN: @@ -3614,14 +3539,7 @@ static void llm_load_tensors( model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); - - if (backend_norm == GGML_BACKEND_GPU) { - vram_weights += ggml_nbytes(model.output_norm); - } - if (backend_output == GGML_BACKEND_GPU_SPLIT) { - vram_weights += ggml_nbytes(model.output); - } - } + } const uint32_t n_ff = hparams.n_ff / 2; @@ -3646,13 +3564,6 @@ static void llm_load_tensors( layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split); layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); - - if (backend == GGML_BACKEND_GPU) { - vram_weights += - ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.wqkv) + ggml_nbytes(layer.bqkv) + - ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) + ggml_nbytes(layer.ffn_gate) + - ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_up); - } } } break; case LLM_ARCH_PHI2: @@ -3676,13 +3587,6 @@ static void llm_load_tensors( model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm); model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); model.output_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, backend_output); - - if (backend_norm == GGML_BACKEND_GPU) { - vram_weights += ggml_nbytes(model.output_norm); - vram_weights += ggml_nbytes(model.output_norm_b); - vram_weights += ggml_nbytes(model.output); - vram_weights += ggml_nbytes(model.output_b); - } } const uint32_t n_ff = hparams.n_ff; @@ -3711,15 +3615,6 @@ static void llm_load_tensors( layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); layer.ffn_up_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend); - - if (backend == GGML_BACKEND_GPU) { - vram_weights += - ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.attn_norm_b) + - ggml_nbytes(layer.wqkv) + ggml_nbytes(layer.bqkv) + - ggml_nbytes(layer.wo) + ggml_nbytes(layer.bo) + - ggml_nbytes(layer.ffn_up) + ggml_nbytes(layer.ffn_up_b) + - ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_down_b); - } } } break; default: @@ -3729,16 +3624,78 @@ static void llm_load_tensors( ml.done_getting_tensors(); + ml.init_mapping(); + + // allocate tensors + size_t vram_weights = 0; + size_t buf_size = 0; + + ggml_backend_buffer_type_t buft = llama_default_buffer_type(n_gpu_layers); + + for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { + // GGML_BACKEND_GPU tensors are for CUDA and OpenCL only, which are handled separately without ggml-backend + if (t->backend == GGML_BACKEND_CPU) { + buf_size += GGML_PAD(ggml_backend_buft_get_alloc_size(buft, t), ggml_backend_buft_get_alignment(buft)); + } else { + vram_weights += ggml_nbytes(t); + } + } + + // create backend buffer + ggml_backend_buffer_t buf_mmap = nullptr; + +#ifdef GGML_USE_METAL + if (n_gpu_layers > 0) { + if (ml.use_mmap) { + const size_t max_size = ggml_get_max_tensor_size(ctx); + model.buf = ggml_backend_metal_buffer_from_ptr(ml.mapping->addr, ml.mapping->size, max_size); + buf_mmap = model.buf; + } else { + model.buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_metal_buffer_type()); + } + } +#elif defined(GGML_USE_CUBLAS) && defined(LLAMA_GGML_BACKEND_CUDA_TEST) + // for testing only + if (n_gpu_layers > 0) { + model.buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_cuda_buffer_type(0)); + } +#endif + + if (model.buf == nullptr) { + // CPU backend, and indirectly CUDA and OpenCL + if (ml.use_mmap) { + model.buf = ggml_backend_cpu_buffer_from_ptr(ml.mapping->addr, ml.mapping->size); + buf_mmap = model.buf; + } else { + // allocate only CPU tensors + model.buf = ggml_backend_buft_alloc_buffer(buft, buf_size); + ggml_tallocr_t alloc = ggml_tallocr_new_from_buffer(model.buf); + for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { + if (t->backend == GGML_BACKEND_CPU) { + ggml_tallocr_alloc(alloc, t); + } + } + ggml_tallocr_free(alloc); + } + } + + if (use_mlock && ggml_backend_buffer_is_host(model.buf)) { + model.mlock_buf.init (ggml_backend_buffer_get_base(model.buf)); + model.mlock_buf.grow_to(ggml_backend_buffer_get_size(model.buf)); + } + // print memory requirements { - // this is the total memory required to run the inference - size_t mem_required = - ctx_size + - mmapped_size - vram_weights; // weights in VRAM not in memory + size_t sys_mem_required = ctx_size + buf_size; - LLAMA_LOG_INFO("%s: mem required = %7.2f MiB\n", __func__, mem_required / 1024.0 / 1024.0); + if (sys_mem_required > 0) { + LLAMA_LOG_INFO("%s: system memory used = %7.2f MiB\n", __func__, sys_mem_required / 1024.0 / 1024.0); + } + if (vram_weights > 0) { + LLAMA_LOG_INFO("%s: VRAM used = %7.2f MiB\n", __func__, vram_weights / 1024.0 / 1024.0); + } -#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) +#if (defined(GGML_USE_CUBLAS) && !defined(LLAMA_GGML_BACKEND_CUDA_TEST)) || defined(GGML_USE_CLBLAST) const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer)); LLAMA_LOG_INFO("%s: offloading %d repeating layers to GPU\n", __func__, n_gpu); @@ -3746,39 +3703,26 @@ static void llm_load_tensors( LLAMA_LOG_INFO("%s: offloading non-repeating layers to GPU\n", __func__); } -#ifdef GGML_USE_CUBLAS - const int max_backend_supported_layers = hparams.n_layer + 1; - const int max_offloadable_layers = hparams.n_layer + 1; -#elif GGML_USE_CLBLAST const int max_backend_supported_layers = hparams.n_layer + 1; const int max_offloadable_layers = hparams.n_layer + 1; -#endif // GGML_USE_CUBLAS LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers); - LLAMA_LOG_INFO("%s: VRAM used: %.2f MiB\n", __func__, vram_weights / 1024.0 / 1024.0); -#else - (void) n_gpu_layers; #endif // defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) } - // populate `tensors_by_name` +#if defined(GGML_USE_CUBLAS) && !defined(LLAMA_GGML_BACKEND_CUDA_TEST) + ggml_cuda_set_tensor_split(tensor_split); +#else + GGML_UNUSED(tensor_split); +#endif // GGML_USE_CUBLAS + + // populate tensors_by_name for (int i = 0; i < ml.n_tensors; ++i) { struct ggml_tensor * cur = ggml_get_tensor(ctx, ml.get_tensor_name(i)); model.tensors_by_name.emplace_back(ggml_get_name(cur), cur); } - (void) tensor_split; -#ifdef GGML_USE_CUBLAS - { - ggml_cuda_set_tensor_split(tensor_split); - } -#endif - - ml.load_all_data(ctx, progress_callback, progress_callback_user_data, use_mlock ? &model.mlock_mmap : NULL); - - if (progress_callback) { - progress_callback(1.0f, progress_callback_user_data); - } + ml.load_all_data(ctx, progress_callback, progress_callback_user_data, buf_mmap, use_mlock ? &model.mlock_mmap : NULL); model.mapping = std::move(ml.mapping); @@ -4211,7 +4155,7 @@ struct llm_build_context { const llm_build_cb & cb; - llama_buffer & buf_compute; + std::vector & buf_compute_meta; struct ggml_context * ctx0 = nullptr; @@ -4221,35 +4165,35 @@ struct llm_build_context { const llama_batch & batch, const llm_build_cb & cb, bool worst_case) : - model (lctx.model), - hparams (model.hparams), - cparams (lctx.cparams), - batch (batch), - kv_self (lctx.kv_self), - n_embd (hparams.n_embd), - n_layer (hparams.n_layer), - n_ctx (cparams.n_ctx), - n_head (hparams.n_head), - n_head_kv (hparams.n_head_kv), - n_embd_head (hparams.n_embd_head()), - n_embd_gqa (hparams.n_embd_gqa()), - n_expert (hparams.n_expert), - n_expert_used (hparams.n_expert_used), - freq_base (cparams.rope_freq_base), - freq_scale (cparams.rope_freq_scale), - ext_factor (cparams.yarn_ext_factor), - attn_factor (cparams.yarn_attn_factor), - beta_fast (cparams.yarn_beta_fast), - beta_slow (cparams.yarn_beta_slow), - norm_eps (hparams.f_norm_eps), - norm_rms_eps (hparams.f_norm_rms_eps), - n_tokens (batch.n_tokens), - n_kv (worst_case ? n_ctx : kv_self.n), - kv_head (worst_case ? n_ctx - n_tokens : kv_self.head), - n_orig_ctx (cparams.n_yarn_orig_ctx), - do_rope_shift (worst_case || kv_self.has_shift), - cb (cb), - buf_compute (lctx.buf_compute) { + model (lctx.model), + hparams (model.hparams), + cparams (lctx.cparams), + batch (batch), + kv_self (lctx.kv_self), + n_embd (hparams.n_embd), + n_layer (hparams.n_layer), + n_ctx (cparams.n_ctx), + n_head (hparams.n_head), + n_head_kv (hparams.n_head_kv), + n_embd_head (hparams.n_embd_head()), + n_embd_gqa (hparams.n_embd_gqa()), + n_expert (hparams.n_expert), + n_expert_used (hparams.n_expert_used), + freq_base (cparams.rope_freq_base), + freq_scale (cparams.rope_freq_scale), + ext_factor (cparams.yarn_ext_factor), + attn_factor (cparams.yarn_attn_factor), + beta_fast (cparams.yarn_beta_fast), + beta_slow (cparams.yarn_beta_slow), + norm_eps (hparams.f_norm_eps), + norm_rms_eps (hparams.f_norm_rms_eps), + n_tokens (batch.n_tokens), + n_kv (worst_case ? n_ctx : kv_self.n), + kv_head (worst_case ? n_ctx - n_tokens : kv_self.head), + n_orig_ctx (cparams.n_yarn_orig_ctx), + do_rope_shift (worst_case || kv_self.has_shift), + cb (cb), + buf_compute_meta (lctx.buf_compute_meta) { GGML_ASSERT(!!kv_self.ctx); // all initializations should be done in init() @@ -4257,8 +4201,8 @@ struct llm_build_context { void init() { struct ggml_init_params params = { - /*.mem_size =*/ buf_compute.size, - /*.mem_buffer =*/ buf_compute.data, + /*.mem_size =*/ buf_compute_meta.size(), + /*.mem_buffer =*/ buf_compute_meta.data(), /*.no_alloc =*/ true, }; @@ -5737,8 +5681,8 @@ static const std::unordered_map k_offload_map { "pos_embd", OFFLOAD_FUNC_NR }, { "inp_pos", OFFLOAD_FUNC_FRC }, // this is often used for KQ ops (e.g. rope) - { "Q_scale", OFFLOAD_FUNC_FRC }, - { "KQ_scale", OFFLOAD_FUNC_FRC }, + { "Q_scale", OFFLOAD_FUNC_NOP }, + { "KQ_scale", OFFLOAD_FUNC_NOP }, { "KQ_mask", OFFLOAD_FUNC_FRC }, { "K_shift", OFFLOAD_FUNC_FRC }, @@ -5845,7 +5789,7 @@ static struct ggml_cgraph * llama_build_graph( bool alloc_inp_KQ_mask = false; bool alloc_inp_K_shift = false; -#ifdef GGML_USE_CUBLAS +#if defined(GGML_USE_CUBLAS) && !defined(LLAMA_GGML_BACKEND_CUDA_TEST) const bool do_offload = true; #else const bool do_offload = true; // TODO: set to false after finishing refactoring @@ -5873,7 +5817,7 @@ static struct ggml_cgraph * llama_build_graph( if (!ggml_allocr_is_measure(lctx.alloc) && batch.token) { const int64_t n_tokens = cur->ne[0]; - memcpy(cur->data, batch.token, n_tokens*ggml_element_size(cur)); + ggml_backend_tensor_set(cur, batch.token, 0, n_tokens*ggml_element_size(cur)); } alloc_inp_tokens = true; @@ -5886,7 +5830,7 @@ static struct ggml_cgraph * llama_build_graph( const int64_t n_embd = cur->ne[0]; const int64_t n_tokens = cur->ne[1]; - memcpy(cur->data, batch.embd, n_tokens*n_embd*ggml_element_size(cur)); + ggml_backend_tensor_set(cur, batch.embd, 0, n_tokens*n_embd*ggml_element_size(cur)); } alloc_inp_embd = true; @@ -5898,11 +5842,8 @@ static struct ggml_cgraph * llama_build_graph( if (!ggml_allocr_is_measure(lctx.alloc) && batch.pos) { const int64_t n_tokens = cur->ne[0]; - int32_t * data = (int32_t *) cur->data; - - for (int i = 0; i < n_tokens; ++i) { - data[i] = batch.pos[i]; - } + static_assert(std::is_same::value, "llama_pos must be int32_t"); + ggml_backend_tensor_set(cur, batch.pos, 0, n_tokens*ggml_element_size(cur)); } alloc_inp_pos = true; @@ -5913,7 +5854,8 @@ static struct ggml_cgraph * llama_build_graph( if (!ggml_allocr_is_measure(lctx.alloc)) { const int64_t n_embd_head = model.hparams.n_embd_head(); - ggml_set_f32(cur, 1.0f/sqrtf(float(n_embd_head))); + float f = 1.0f/sqrtf(float(n_embd_head)); + ggml_backend_tensor_set(cur, &f, 0, sizeof(f)); } alloc_inp_Q_scale = true; @@ -5924,13 +5866,15 @@ static struct ggml_cgraph * llama_build_graph( if (!ggml_allocr_is_measure(lctx.alloc)) { const int64_t n_embd_head = model.hparams.n_embd_head(); + float f; if (model.arch == LLM_ARCH_PHI2) { // with phi2, we scale the Q to avoid precision issues // ref: https://github.com/ml-explore/mlx-examples/blob/08e862336ade809bc37d1035f94b359e7d1a5152/phi2/phi2.py#L64-L66 - ggml_set_f32(cur, 1.0f); + f = 1.0f; } else { - ggml_set_f32(cur, 1.0f/sqrtf(float(n_embd_head))); + f = 1.0f/sqrtf(float(n_embd_head)); } + ggml_backend_tensor_set(cur, &f, 0, sizeof(f)); } alloc_inp_KQ_scale = true; @@ -5943,8 +5887,13 @@ static struct ggml_cgraph * llama_build_graph( const int64_t n_kv = cur->ne[0]; const int64_t n_tokens = cur->ne[1]; - float * data = (float *) cur->data; - memset(data, 0, ggml_nbytes(cur)); + float * data; + if (ggml_backend_buffer_is_host(cur->buffer)) { + data = (float *) cur->data; + } else { + lctx.buf_copy.resize(ggml_nbytes(cur)); + data = (float *) lctx.buf_copy.data(); + } for (int h = 0; h < 1; ++h) { for (int j = 0; j < n_tokens; ++j) { @@ -5952,12 +5901,20 @@ static struct ggml_cgraph * llama_build_graph( const llama_seq_id seq_id = batch.seq_id[j][0]; for (int i = 0; i < n_kv; ++i) { + float f; if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) { - data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; + f = -INFINITY; + } else { + f = 0; } + data[h*(n_kv*n_tokens) + j*n_kv + i] = f; } } } + + if (data != cur->data) { + ggml_backend_tensor_set(cur, data, 0, ggml_nbytes(cur)); + } } alloc_inp_KQ_mask = true; @@ -5969,11 +5926,21 @@ static struct ggml_cgraph * llama_build_graph( if (!ggml_allocr_is_measure(lctx.alloc)) { const int64_t n_ctx = cur->ne[0]; - int32_t * data = (int32_t *) cur->data; + int32_t * data; + if (ggml_backend_buffer_is_host(cur->buffer)) { + data = (int32_t *) cur->data; + } else { + lctx.buf_copy.resize(ggml_nbytes(cur)); + data = (int32_t *) lctx.buf_copy.data(); + } for (int i = 0; i < n_ctx; ++i) { data[i] = lctx.kv_self.cells[i].delta; } + + if (data != cur->data) { + ggml_backend_tensor_set(cur, data, 0, ggml_nbytes(cur)); + } } alloc_inp_K_shift = true; @@ -6010,7 +5977,7 @@ static struct ggml_cgraph * llama_build_graph( static const std::unordered_map> k_offload_func_name = { { OFFLOAD_FUNC_NOP, "CPU" }, { OFFLOAD_FUNC_OUT, "CPU" }, -#ifdef GGML_USE_CUBLAS +#if defined(GGML_USE_CUBLAS) && !defined(LLAMA_GGML_BACKEND_CUDA_TEST) { OFFLOAD_FUNC, "GPU (CUDA)" }, { OFFLOAD_FUNC_FRC, "GPU (CUDA) FRC" }, { OFFLOAD_FUNC_KQV, "GPU (CUDA) KQV" }, @@ -6083,7 +6050,7 @@ static struct ggml_cgraph * llama_build_graph( offload_func_t func = ggml_offload_nop; // this is needed for compatibility with Metal for example -#ifdef GGML_USE_CUBLAS +#if defined(GGML_USE_CUBLAS) && !defined(LLAMA_GGML_BACKEND_CUDA_TEST) static offload_func_t ggml_offload_gpu = ggml_cuda_assign_buffers_no_alloc; #else static offload_func_t ggml_offload_gpu = ggml_offload_nop; @@ -6305,11 +6272,12 @@ static int llama_decode_internal( GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0); } -#ifdef GGML_USE_CUBLAS +#if defined(GGML_USE_CUBLAS) && !defined(LLAMA_GGML_BACKEND_CUDA_TEST) + char * buf_alloc_base = (char *)ggml_backend_buffer_get_base(lctx.buf_alloc); for (int i = 0; i < gf->n_leafs; i++) { ggml_tensor * node = gf->leafs[i]; if (node->backend == GGML_BACKEND_GPU && node->extra == NULL) { - ggml_cuda_assign_scratch_offset(node, (char*)node->data - (char *) lctx.buf_alloc.data); + ggml_cuda_assign_scratch_offset(node, (char *)node->data - buf_alloc_base); ggml_cuda_copy_to_device(node); } } @@ -6317,7 +6285,7 @@ static int llama_decode_internal( for (int i = 0; i < gf->n_nodes; i++) { ggml_tensor * node = gf->nodes[i]; if (node->backend == GGML_BACKEND_GPU && node->extra == NULL) { - ggml_cuda_assign_scratch_offset(node, (char*)node->data - (char *) lctx.buf_alloc.data); + ggml_cuda_assign_scratch_offset(node, (char *)node->data - buf_alloc_base); } } @@ -6344,23 +6312,23 @@ static int llama_decode_internal( n_threads = 1; } -#if GGML_USE_MPI +#ifdef GGML_USE_MPI const int64_t n_layer = hparams.n_layer; ggml_mpi_graph_compute_pre(lctx.ctx_mpi, gf, n_layer); #endif #ifdef GGML_USE_METAL - if (lctx.ctx_metal) { - ggml_metal_set_n_cb (lctx.ctx_metal, n_threads); - ggml_metal_graph_compute(lctx.ctx_metal, gf); - } else { - ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads); + if (ggml_backend_is_metal(lctx.backend)) { + ggml_backend_metal_set_n_cb(lctx.backend, n_threads); } -#else - ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads); #endif -#if GGML_USE_MPI + if (ggml_backend_is_cpu(lctx.backend)) { + ggml_backend_cpu_set_n_threads(lctx.backend, n_threads); + } + ggml_backend_graph_compute(lctx.backend, gf); + +#ifdef GGML_USE_MPI ggml_mpi_graph_compute_post(lctx.ctx_mpi, gf, n_layer); #endif @@ -6412,20 +6380,20 @@ static int llama_decode_internal( if (batch.logits[i] == 0) { continue; } - memcpy(logits_out.data() + (n_vocab*i), (float *) ggml_get_data(res) + (n_vocab*i), sizeof(float)*n_vocab); + ggml_backend_tensor_get(res, logits_out.data() + (n_vocab*i), (n_vocab*i)*sizeof(float), n_vocab*sizeof(float)); #ifndef NDEBUG logits_valid[i] = true; #endif } } else if (lctx.logits_all) { logits_out.resize(n_vocab * n_tokens); - memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*n_tokens); + ggml_backend_tensor_get(res, logits_out.data(), 0, n_vocab*n_tokens*sizeof(float)); #ifndef NDEBUG std::fill(logits_valid.begin(), logits_valid.end(), true); #endif } else { logits_out.resize(n_vocab); - memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(n_tokens - 1)), sizeof(float)*n_vocab); + ggml_backend_tensor_get(res, logits_out.data(), (n_vocab*(n_tokens - 1))*sizeof(float), n_vocab*sizeof(float)); #ifndef NDEBUG logits_valid[0] = true; #endif @@ -6437,7 +6405,7 @@ static int llama_decode_internal( auto & embedding_out = lctx.embedding; embedding_out.resize(n_embd); - memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(n_tokens - 1)), sizeof(float)*n_embd); + ggml_backend_tensor_get(embeddings, embedding_out.data(), (n_embd*(n_tokens - 1))*sizeof(float), n_embd*sizeof(float)); } // measure the performance only for the single-token evals @@ -8395,12 +8363,6 @@ void llama_beam_search(llama_context * ctx, // quantization // -template -struct no_init { - T value; - no_init() { /* do nothing */ } -}; - struct quantize_state_internal { const llama_model & model; const llama_model_quantize_params * params; @@ -8643,9 +8605,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s #endif llama_model_loader ml(fname_inp, use_mmap, NULL); - if (ml.use_mmap) { - ml.mapping.reset(new llama_mmap(&ml.file, /* prefetch */ 0, ggml_is_numa())); - } + ml.init_mapping(false); // no prefetching? llama_model model; llm_load_arch(ml, model); @@ -8944,29 +8904,10 @@ static int llama_apply_lora_from_file_internal( // load base model std::unique_ptr ml; - unique_context base_ctx(nullptr, ggml_free); - std::vector base_buf; - if (path_base_model) { + if (path_base_model) { LLAMA_LOG_INFO("%s: loading base model from '%s'\n", __func__, path_base_model); - ml.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*kv_overrides*/ NULL)); - - size_t ctx_size; - size_t mmapped_size; - ml->calc_sizes(ctx_size, mmapped_size); - - base_buf.resize(ctx_size); - - ggml_init_params base_params; - base_params.mem_size = base_buf.size(); - base_params.mem_buffer = base_buf.data(); - base_params.no_alloc = ml->use_mmap; - - base_ctx.reset(ggml_init(base_params)); - - // maybe this should be in llama_model_loader - if (ml->use_mmap) { - ml->mapping.reset(new llama_mmap(&ml->file, /* prefetch */ 0, ggml_is_numa())); - } + ml.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*kv_overrides*/ nullptr)); + ml->init_mapping(false); // no prefetching } // read tensors and apply @@ -9058,7 +8999,7 @@ static int llama_apply_lora_from_file_internal( offload_func_t offload_func = ggml_offload_nop; offload_func_t offload_func_force_inplace = ggml_offload_nop; -#ifdef GGML_USE_CUBLAS +#if defined(GGML_USE_CUBLAS) && !defined(LLAMA_GGML_BACKEND_CUDA_TEST) if (dest_t->backend == GGML_BACKEND_GPU || dest_t->backend == GGML_BACKEND_GPU_SPLIT) { if (dest_t->type != GGML_TYPE_F16) { throw std::runtime_error(format( @@ -9079,7 +9020,7 @@ static int llama_apply_lora_from_file_internal( return 1; } - base_t = ml->create_tensor(base_ctx.get(), base_name, { dest_t->ne[0], dest_t->ne[1] }, GGML_BACKEND_CPU); + base_t = ml->get_tensor_meta(base_name.c_str()); ml->load_data_for(base_t); } else { base_t = dest_t; @@ -9364,7 +9305,39 @@ struct llama_context * llama_new_context_with_model( // reserve memory for context buffers if (!hparams.vocab_only) { - if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, type_k, type_v, cparams.n_ctx, model->n_gpu_layers, cparams.offload_kqv)) { + // initialize backend +#ifdef GGML_USE_METAL + if (model->n_gpu_layers > 0) { + ctx->backend = ggml_backend_metal_init(); + if (ctx->backend == nullptr) { + LLAMA_LOG_ERROR("%s: failed to initialize Metal backend\n", __func__); + } + } +#elif defined(GGML_USE_CUBLAS) && defined(LLAMA_GGML_BACKEND_CUDA_TEST) + // for testing only + if (model->n_gpu_layers > 0) { + ctx->backend = ggml_backend_cuda_init(0); + if (ctx->backend == nullptr) { + LLAMA_LOG_ERROR("%s: failed to initialize CUDA backend\n", __func__); + } + } +#endif + + if (ctx->backend == nullptr && ggml_backend_buffer_is_host(model->buf)) { + ctx->backend = ggml_backend_cpu_init(); + if (ctx->backend == nullptr) { + LLAMA_LOG_ERROR("%s: failed to initialize CPU backend\n", __func__); + } + } + + if (ctx->backend == nullptr) { + LLAMA_LOG_ERROR("%s: failed to initialize a backend\n", __func__); + delete ctx; + return nullptr; + } + + if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, type_k, type_v, + cparams.n_ctx, model->n_gpu_layers, cparams.offload_kqv)) { LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); llama_free(ctx); return nullptr; @@ -9400,12 +9373,11 @@ struct llama_context * llama_new_context_with_model( } { - static const size_t tensor_alignment = 32; // the compute buffer is used to store the tensor and graph structs, while the allocator buffer is used for the tensor data - ctx->buf_compute.resize(ggml_tensor_overhead()*LLAMA_MAX_NODES + ggml_graph_overhead()); + ctx->buf_compute_meta.resize(ggml_tensor_overhead()*LLAMA_MAX_NODES + ggml_graph_overhead()); // create measure allocator - ctx->alloc = ggml_allocr_new_measure(tensor_alignment); + ctx->alloc = ggml_allocr_new_measure_from_backend(ctx->backend); // build worst-case graph int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_batch); @@ -9413,98 +9385,50 @@ struct llama_context * llama_new_context_with_model( llama_token token = llama_token_bos(&ctx->model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, n_past, 0)); -#ifdef GGML_USE_METAL - if (model->n_gpu_layers > 0) { - ctx->ctx_metal = ggml_metal_init(1); - if (!ctx->ctx_metal) { - LLAMA_LOG_ERROR("%s: ggml_metal_init() failed\n", __func__); - llama_free(ctx); - return NULL; - } - //ggml_metal_graph_find_concurrency(ctx->ctx_metal, gf, false); - //ggml_allocr_set_parse_seq(ctx->alloc, ggml_metal_get_concur_list(ctx->ctx_metal), ggml_metal_if_optimized(ctx->ctx_metal)); - } -#endif // measure memory requirements for the graph - size_t alloc_size = ggml_allocr_alloc_graph(ctx->alloc, gf) + tensor_alignment; + size_t alloc_size = ggml_allocr_alloc_graph(ctx->alloc, gf); - LLAMA_LOG_INFO("%s: compute buffer total size = %.2f MiB\n", __func__, (ctx->buf_compute.size + alloc_size) / 1024.0 / 1024.0); + LLAMA_LOG_INFO("%s: compute buffer total size = %.2f MiB\n", __func__, (ctx->buf_compute_meta.size() + alloc_size) / 1024.0 / 1024.0); - // recreate allocator with exact memory requirements + // create allocator again with exact memory requirements ggml_allocr_free(ctx->alloc); - ctx->buf_alloc.resize(alloc_size); - ctx->alloc = ggml_allocr_new(ctx->buf_alloc.data, ctx->buf_alloc.size, tensor_alignment); -#ifdef GGML_USE_METAL - if (ctx->ctx_metal) { - //ggml_allocr_set_parse_seq(ctx->alloc, ggml_metal_get_concur_list(ctx->ctx_metal), ggml_metal_if_optimized(ctx->ctx_metal)); - } -#endif -#ifdef GGML_USE_CUBLAS - ggml_cuda_set_scratch_size(alloc_size); - LLAMA_LOG_INFO("%s: VRAM scratch buffer: %.2f MiB\n", __func__, alloc_size / 1024.0 / 1024.0); + ctx->buf_alloc = ggml_backend_alloc_buffer(ctx->backend, alloc_size); + ctx->alloc = ggml_allocr_new_from_buffer(ctx->buf_alloc); +#if defined(GGML_USE_CUBLAS) && !defined(LLAMA_GGML_BACKEND_CUDA_TEST) + if (model->n_gpu_layers > 0) { + ggml_cuda_set_scratch_size(alloc_size); + LLAMA_LOG_INFO("%s: VRAM scratch buffer: %.2f MiB\n", __func__, alloc_size / 1024.0 / 1024.0); - // calculate total VRAM usage - auto add_tensor = [](const ggml_tensor * t, size_t & size) { - if (t->backend == GGML_BACKEND_GPU || t->backend == GGML_BACKEND_GPU_SPLIT) { - size += ggml_nbytes(t); + // calculate total VRAM usage + auto add_tensor = [](const ggml_tensor * t, size_t & size) { + if (t->backend == GGML_BACKEND_GPU || t->backend == GGML_BACKEND_GPU_SPLIT) { + size += ggml_nbytes(t); + } + }; + size_t model_vram_size = 0; + for (const auto & kv : model->tensors_by_name) { + add_tensor(kv.second, model_vram_size); } - }; - size_t model_vram_size = 0; - for (const auto & kv : model->tensors_by_name) { - add_tensor(kv.second, model_vram_size); - } - size_t kv_vram_size = 0; - for (auto & k : ctx->kv_self.k_l) { - add_tensor(k, kv_vram_size); - } - for (auto & v : ctx->kv_self.v_l) { - add_tensor(v, kv_vram_size); - } - - size_t ctx_vram_size = alloc_size + kv_vram_size; - size_t total_vram_size = model_vram_size + ctx_vram_size; - - LLAMA_LOG_INFO("%s: total VRAM used: %.2f MiB (model: %.2f MiB, context: %.2f MiB)\n", __func__, - total_vram_size / 1024.0 / 1024.0, - model_vram_size / 1024.0 / 1024.0, - ctx_vram_size / 1024.0 / 1024.0); -#endif - } - -#ifdef GGML_USE_METAL - if (model->n_gpu_layers > 0) { - // this allocates all Metal resources and memory buffers - - void * data_ptr = NULL; - size_t data_size = 0; - - if (ctx->model.mapping) { - data_ptr = ctx->model.mapping->addr; - data_size = ctx->model.mapping->size; - } else { - data_ptr = ggml_get_mem_buffer(ctx->model.ctx); - data_size = ggml_get_mem_size (ctx->model.ctx); - } - - const size_t max_size = ggml_get_max_tensor_size(ctx->model.ctx); + size_t kv_vram_size = 0; + for (auto & k : ctx->kv_self.k_l) { + add_tensor(k, kv_vram_size); + } + for (auto & v : ctx->kv_self.v_l) { + add_tensor(v, kv_vram_size); + } - LLAMA_LOG_INFO("%s: max tensor size = %8.2f MiB\n", __func__, max_size/1024.0/1024.0); + size_t ctx_vram_size = alloc_size + kv_vram_size; + size_t total_vram_size = model_vram_size + ctx_vram_size; -#define LLAMA_METAL_CHECK_BUF(result) \ - if (!(result)) { \ - LLAMA_LOG_ERROR("%s: failed to add buffer\n", __func__); \ - llama_free(ctx); \ - return NULL; \ + LLAMA_LOG_INFO("%s: total VRAM used: %.2f MiB (model: %.2f MiB, context: %.2f MiB)\n", __func__, + total_vram_size / 1024.0 / 1024.0, + model_vram_size / 1024.0 / 1024.0, + ctx_vram_size / 1024.0 / 1024.0); } - - LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "data", data_ptr, data_size, max_size)); - LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "kv", ctx->kv_self.buf.data, ctx->kv_self.buf.size, 0)); - LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "alloc", ctx->buf_alloc.data, ctx->buf_alloc.size, 0)); -#undef LLAMA_METAL_CHECK_BUF - } #endif + } } #ifdef GGML_USE_MPI @@ -9796,7 +9720,7 @@ size_t llama_get_state_size(const struct llama_context * ctx) { const size_t s_embedding = ctx->embedding.size() * sizeof(float); const size_t s_kv_size = sizeof(size_t); const size_t s_kv_ntok = sizeof(int); - const size_t s_kv = ctx->kv_self.buf.size; + const size_t s_kv = ggml_backend_buffer_get_size(ctx->kv_self.buf); const size_t s_total = ( + s_rng_size @@ -9924,7 +9848,7 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat const auto n_embd = hparams.n_embd_gqa(); const auto n_ctx = cparams.n_ctx; - const size_t kv_buf_size = kv_self.buf.size; + const size_t kv_buf_size = ggml_backend_buffer_get_size(kv_self.buf); const uint32_t kv_head = kv_self.head; const uint32_t kv_size = kv_self.size; const uint32_t kv_used = kv_self.used; @@ -9940,17 +9864,12 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat ggml_context * cpy_ctx = ggml_init({ 6*n_layer*ggml_tensor_overhead() + ggml_graph_overhead(), NULL, /* no_alloc */ true }); ggml_cgraph * gf = ggml_new_graph(cpy_ctx); - std::vector> kout2d_data(n_layer); - std::vector> vout2d_data(n_layer); + std::vector kout2d(n_layer); + std::vector vout2d(n_layer); for (int il = 0; il < (int) n_layer; ++il) { - ggml_tensor * kout2d = ggml_new_tensor_2d(cpy_ctx, kv_self.k_l[il]->type, n_embd, kv_head); - kout2d_data[il].resize(ggml_nbytes(kout2d)); - kout2d->data = kout2d_data[il].data(); - - ggml_tensor * vout2d = ggml_new_tensor_2d(cpy_ctx, kv_self.v_l[il]->type, kv_head, n_embd); - vout2d_data[il].resize(ggml_nbytes(vout2d)); - vout2d->data = vout2d_data[il].data(); + kout2d[il] = ggml_new_tensor_2d(cpy_ctx, kv_self.k_l[il]->type, n_embd, kv_head); + vout2d[il] = ggml_new_tensor_2d(cpy_ctx, kv_self.v_l[il]->type, kv_head, n_embd); ggml_tensor * k2d = ggml_view_2d(cpy_ctx, kv_self.k_l[il], n_embd, kv_head, @@ -9960,20 +9879,28 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat kv_head, n_embd, elt_size*n_ctx, 0); - ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, k2d, kout2d)); - ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, v2d, vout2d)); + ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, k2d, kout2d[il])); + ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, v2d, vout2d[il])); } - ggml_graph_compute_helper(ctx->work_buffer, gf, /*n_threads*/ 1); + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(cpy_ctx, ctx->backend); - ggml_free(cpy_ctx); + ggml_backend_graph_compute(ctx->backend, gf); + + std::vector tmp_buf; + for (int il = 0; il < (int) n_layer; ++il) { + tmp_buf.resize(ggml_nbytes(kout2d[il])); + ggml_backend_tensor_get(kout2d[il], tmp_buf.data(), 0, tmp_buf.size()); + data_ctx->write(tmp_buf.data(), tmp_buf.size()); - // our data is now in the kout2d_data and vout2d_data buffers - // write them to file - for (uint32_t il = 0; il < n_layer; ++il) { - data_ctx->write(kout2d_data[il].data(), kout2d_data[il].size()); - data_ctx->write(vout2d_data[il].data(), vout2d_data[il].size()); + tmp_buf.resize(ggml_nbytes(vout2d[il])); + ggml_backend_tensor_get(vout2d[il], tmp_buf.data(), 0, tmp_buf.size()); + data_ctx->write(tmp_buf.data(), tmp_buf.size()); } + + ggml_free(cpy_ctx); + + ggml_backend_buffer_free(buf); } for (uint32_t i = 0; i < kv_size; ++i) { @@ -10071,21 +9998,19 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { memcpy(&kv_used, inp, sizeof(kv_used)); inp += sizeof(kv_used); if (kv_buf_size) { - GGML_ASSERT(kv_self.buf.size == kv_buf_size); + GGML_ASSERT(ggml_backend_buffer_get_size(kv_self.buf) == kv_buf_size); const size_t elt_size = ggml_element_size(kv_self.k_l[0]); ggml_context * cpy_ctx = ggml_init({ 6*n_layer*ggml_tensor_overhead() + ggml_graph_overhead(), NULL, /* no_alloc */ true }); ggml_cgraph * gf = ggml_new_graph(cpy_ctx); - for (int il = 0; il < n_layer; ++il) { - ggml_tensor * kin2d = ggml_new_tensor_2d(cpy_ctx, kv_self.k_l[il]->type, n_embd, kv_head); - kin2d->data = (void *) inp; - inp += ggml_nbytes(kin2d); + std::vector kin2d(n_layer); + std::vector vin2d(n_layer); - ggml_tensor * vin2d = ggml_new_tensor_2d(cpy_ctx, kv_self.v_l[il]->type, kv_head, n_embd); - vin2d->data = (void *) inp; - inp += ggml_nbytes(vin2d); + for (int il = 0; il < n_layer; ++il) { + kin2d[il] = ggml_new_tensor_2d(cpy_ctx, kv_self.k_l[il]->type, n_embd, kv_head); + vin2d[il] = ggml_new_tensor_2d(cpy_ctx, kv_self.v_l[il]->type, kv_head, n_embd); ggml_tensor * k2d = ggml_view_2d(cpy_ctx, kv_self.k_l[il], n_embd, kv_head, @@ -10095,13 +10020,26 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { kv_head, n_embd, elt_size*n_ctx, 0); - ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, kin2d, k2d)); - ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, vin2d, v2d)); + ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, kin2d[il], k2d)); + ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, vin2d[il], v2d)); } - ggml_graph_compute_helper(ctx->work_buffer, gf, /*n_threads*/ 1); + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(cpy_ctx, ctx->backend); + + // load data into the tensors + for (int il = 0; il < n_layer; ++il) { + ggml_backend_tensor_set(kin2d[il], inp, 0, ggml_nbytes(kin2d[il])); + inp += ggml_nbytes(kin2d[il]); + + ggml_backend_tensor_set(vin2d[il], inp, 0, ggml_nbytes(vin2d[il])); + inp += ggml_nbytes(vin2d[il]); + } + + ggml_backend_graph_compute(ctx->backend, gf); ggml_free(cpy_ctx); + + ggml_backend_buffer_free(buf); } ctx->kv_self.head = kv_head;