From bc5e9fc63274ea1e743eedee8a83d802a87c4482 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 25 May 2024 10:47:15 +0300 Subject: [PATCH] Revert "whisper : remove extra backend instance (huh?)" This reverts commit 4caa64b73ed4c0e71097c865b0f6a9c136b007c6. --- whisper.cpp | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/whisper.cpp b/whisper.cpp index 84aec8238cd..7b8c683fca7 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -818,6 +818,8 @@ struct whisper_state { whisper_decoder decoders[WHISPER_MAX_DECODERS]; + ggml_backend_t backend = nullptr; + // ggml-alloc: // - stores meta info about the intermediate tensors into the `meta` buffers // - stores the actual tensor data into the `data` buffers @@ -2261,7 +2263,7 @@ static bool whisper_encode_internal( } if (!whisper_encode_external(wstate)) { - if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) { + if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) { return false; } } else { @@ -2284,7 +2286,7 @@ static bool whisper_encode_internal( return false; } - if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) { + if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) { return false; } } @@ -2300,7 +2302,7 @@ static bool whisper_encode_internal( return false; } - if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) { + if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) { return false; } } @@ -2801,7 +2803,7 @@ static bool whisper_decode_internal( logits = gf->nodes[gf->n_nodes - 1]; - if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) { + if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) { return false; } } @@ -3248,6 +3250,13 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { whisper_state * state = new whisper_state; + state->backend = whisper_backend_init(ctx->params); + if (!state->backend) { + WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__); + whisper_free_state(state); + return nullptr; + } + // at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx // in theory, there can be a case where this is not enough, but in practice it should always be enough const int factor = 3; @@ -3684,6 +3693,8 @@ void whisper_free_state(struct whisper_state * state) { ggml_gallocr_free(state->alloc_cross.alloc); ggml_gallocr_free(state->alloc_decode.alloc); + ggml_backend_free(state->backend); + // [EXPERIMENTAL] Token-level timestamps with DTW aheads_masks_free(state->aheads_masks);