Skip to content

Commit

Permalink
Improve decoding (ggerganov#291)
Browse files Browse the repository at this point in the history
* whisper : prepare infra for new decoding strategies

* whisper : apply logit filters and compute logprobs

* whisper : add whisper_get_logits()

* whisper : separate self and cross attention memory

Initial step needed for supporting parallel decoders

* whisper : move probs_id buffer to whisper_context

* whisper : refactor kv cache into separate struct

* whisper : move self-attention kv cache to whisper_decoder

* whisper : wip decoding parameters + strategies

* whisper : wip decoding parameters + strategies (part 2)

* whisper : wip decoding parameters + strategies (part 3)

* whisper : wip decoding parameters + strategies (part 4)

* whisper : fix prompt_past update to not include prompt_init

* whisper : temperature + best_of support

* whisper : support for compression_ration_threshold

We actually use entropy, but it is similar

* command : fix example to use logits instead of obsolete probs

* whisper : handle empty sequence ranking

* whisper : add WHISPER_DEBUG + diagnostic prints + new main args

* whisper : minor fixes

* whisper : add beam-search support

* whisper : bug fix when there no previous context

* whisper : add comments

* stream : disable temperature fallback

For real-time processing, we always want a single decoder running at T=0

* whisper.swiftui : update example - fix paths + add empty folders
  • Loading branch information
ggerganov authored Jan 15, 2023
1 parent 600f764 commit 87191b6
Show file tree
Hide file tree
Showing 11 changed files with 1,539 additions and 792 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ build/
build-em/
build-debug/
build-release/
build-static/
build-sanitize-addr/
build-sanitize-thread/

Expand All @@ -18,6 +19,7 @@ build-sanitize-thread/
/bench

sync.sh
libwhisper.a
libwhisper.so
compile_commands.json

Expand Down
12 changes: 1 addition & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -212,17 +212,7 @@ make large
## Limitations

- Inference only
- No GPU support
- Very basic greedy sampling scheme - always pick up the token with highest probability.
This should be similar to the [GreedyDecoder](https://github.com/openai/whisper/blob/main/whisper/decoding.py#L249-L274)
from the original python implementation, so in order to make a fair comparison between the 2 implementations, make sure
to run the python code with the following parameters:

```
whisper --best_of None --beam_size None ...
```

In the future, `whisper.cpp` will support more sampling strategies.
- No GPU support (yet)

## Another example

Expand Down
107 changes: 66 additions & 41 deletions examples/command/command.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -671,56 +671,81 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const
break;
}

const auto * probs = whisper_get_probs(ctx);
std::vector<std::pair<float, int>> probs_id;

double psum = 0.0;
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
probs_id.emplace_back(probs[allowed_tokens[i][0]], i);
for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) {
probs_id.back().first += probs[allowed_tokens[i][j]];
}
probs_id.back().first /= allowed_tokens[i].size();
psum += probs_id.back().first;
}
// estimate command probability
// NOTE: not optimal
{
const auto * logits = whisper_get_logits(ctx);

// normalize
for (auto & p : probs_id) {
p.first /= psum;
}
std::vector<float> probs(whisper_n_vocab(ctx), 0.0f);

// sort descending
{
using pair_type = decltype(probs_id)::value_type;
std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
return a.first > b.first;
});
}
// compute probs from logits via softmax
{
float max = -1e9;
for (int i = 0; i < (int) probs.size(); ++i) {
max = std::max(max, logits[i]);
}

// print the commands and the respective probabilities
{
fprintf(stdout, "\n");
for (const auto & cmd : probs_id) {
fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first);
for (int token : allowed_tokens[cmd.second]) {
fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, token), probs[token]);
float sum = 0.0f;
for (int i = 0; i < (int) probs.size(); ++i) {
probs[i] = expf(logits[i] - max);
sum += probs[i];
}

for (int i = 0; i < (int) probs.size(); ++i) {
probs[i] /= sum;
}
}

std::vector<std::pair<float, int>> probs_id;

double psum = 0.0;
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
probs_id.emplace_back(probs[allowed_tokens[i][0]], i);
for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) {
probs_id.back().first += probs[allowed_tokens[i][j]];
}
probs_id.back().first /= allowed_tokens[i].size();
psum += probs_id.back().first;
}

// normalize
for (auto & p : probs_id) {
p.first /= psum;
}

// sort descending
{
using pair_type = decltype(probs_id)::value_type;
std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
return a.first > b.first;
});
}

// print the commands and the respective probabilities
{
fprintf(stdout, "\n");
for (const auto & cmd : probs_id) {
fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first);
for (int token : allowed_tokens[cmd.second]) {
fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, token), probs[token]);
}
fprintf(stdout, "\n");
}
}
}

// best command
{
const auto t_end = std::chrono::high_resolution_clock::now();
// best command
{
const auto t_end = std::chrono::high_resolution_clock::now();

const float prob = probs_id[0].first;
const int index = probs_id[0].second;
const float prob = probs_id[0].first;
const int index = probs_id[0].second;

fprintf(stdout, "\n");
fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__,
"\033[1m", allowed_commands[index].c_str(), "\033[0m", prob,
(int) std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count());
fprintf(stdout, "\n");
fprintf(stdout, "\n");
fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__,
"\033[1m", allowed_commands[index].c_str(), "\033[0m", prob,
(int) std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count());
fprintf(stdout, "\n");
}
}

audio.clear();
Expand Down
91 changes: 55 additions & 36 deletions examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,12 @@ struct whisper_params {
int32_t duration_ms = 0;
int32_t max_context = -1;
int32_t max_len = 0;
int32_t best_of = 5;
int32_t beam_size = -1;

float word_thold = 0.01f;
float word_thold = 0.01f;
float entropy_thold = 2.4f;
float logprob_thold = -1.0f;

bool speed_up = false;
bool translate = false;
Expand Down Expand Up @@ -104,7 +108,11 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); }
else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); }
else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); }
else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); }
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
Expand Down Expand Up @@ -136,31 +144,35 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors);
fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms);
fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n);
fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false");
fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors);
fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms);
fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n);
fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of);
fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size);
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false");
fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
fprintf(stderr, "\n");
}

Expand Down Expand Up @@ -235,7 +247,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
const char * text = whisper_full_get_token_text(ctx, i, j);
const float p = whisper_full_get_token_p (ctx, i, j);

const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(p, 3)*float(k_colors.size()))));

printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
}
Expand Down Expand Up @@ -331,20 +343,19 @@ bool output_csv(struct whisper_context * ctx, const char * fname) {
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
if (text[0] == ' ')
text = text + sizeof(char); //whisper_full_get_segment_text() returns a string with leading space, point to the next character.
if (text[0] == ' ') {
text = text + sizeof(char); //whisper_full_get_segment_text() returns a string with leading space, point to the next character.
}
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
//need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds.
fout << 10 * t0 << ", "
<< 10 * t1 << ", \""
<< text << "\"\n";

//need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds.
fout << 10 * t0 << ", " << 10 * t1 << ", \"" << text << "\"\n";
}

return true;
}


// karaoke video generation
// outputs a bash script that uses ffmpeg to generate a video with the subtitles
// TODO: font parameter adjustments
Expand Down Expand Up @@ -620,6 +631,8 @@ int main(int argc, char ** argv) {
{
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;

wparams.print_realtime = false;
wparams.print_progress = params.print_progress;
wparams.print_timestamps = !params.no_timestamps;
Expand All @@ -633,12 +646,18 @@ int main(int argc, char ** argv) {

wparams.token_timestamps = params.output_wts || params.max_len > 0;
wparams.thold_pt = params.word_thold;
wparams.entropy_thold = params.entropy_thold;
wparams.logprob_thold = params.logprob_thold;
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;

wparams.speed_up = params.speed_up;

wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
wparams.greedy.best_of = params.best_of;
wparams.beam_search.beam_size = params.beam_size;
wparams.temperature_inc = -1;

wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();

whisper_print_user_data user_data = { &params, &pcmf32s };

Expand Down
3 changes: 3 additions & 0 deletions examples/stream.wasm/emscripten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ void stream_main(size_t index) {
wparams.max_tokens = 32;
wparams.audio_ctx = 768; // partial encoder context for better performance

// disable temperature fallback
wparams.temperature_inc = -1.0f;

wparams.language = "en";

printf("stream: using %d threads\n", wparams.n_threads);
Expand Down
3 changes: 3 additions & 0 deletions examples/stream/stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,9 @@ int main(int argc, char ** argv) {
wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;

// disable temperature fallback
wparams.temperature_inc = -1.0f;

wparams.prompt_tokens = params.no_context ? nullptr : prompt_tokens.data();
wparams.prompt_n_tokens = params.no_context ? 0 : prompt_tokens.size();

Expand Down
Empty file.
Empty file.
Loading

0 comments on commit 87191b6

Please sign in to comment.