Skip to content

Commit abd42ba

Browse files
authored
Improve decoding (ggml-org#291)
* whisper : prepare infra for new decoding strategies
* whisper : apply logit filters and compute logprobs
* whisper : add whisper_get_logits()
* whisper : separate self and cross attention memory
  (Initial step needed for supporting parallel decoders.)
* whisper : move probs_id buffer to whisper_context
* whisper : refactor kv cache into separate struct
* whisper : move self-attention kv cache to whisper_decoder
* whisper : wip decoding parameters + strategies
* whisper : wip decoding parameters + strategies (part 2)
* whisper : wip decoding parameters + strategies (part 3)
* whisper : wip decoding parameters + strategies (part 4)
* whisper : fix prompt_past update to not include prompt_init
* whisper : temperature + best_of support
* whisper : support for compression_ratio_threshold
  (We actually use entropy, but it is similar.)
* command : fix example to use logits instead of obsolete probs
* whisper : handle empty sequence ranking
* whisper : add WHISPER_DEBUG + diagnostic prints + new main args
* whisper : minor fixes
* whisper : add beam-search support
* whisper : bug fix when there is no previous context
* whisper : add comments
* stream : disable temperature fallback
  (For real-time processing, we always want a single decoder running at T=0.)
* whisper.swiftui : update example - fix paths + add empty folders
1 parent 0cc5b6f commit abd42ba

File tree

11 files changed

+1539
-792
lines changed

11 files changed

+1539
-792
lines changed

.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ build/
88
build-em/
99
build-debug/
1010
build-release/
11+
build-static/
1112
build-sanitize-addr/
1213
build-sanitize-thread/
1314

@@ -18,6 +19,7 @@ build-sanitize-thread/
1819
/bench
1920

2021
sync.sh
22+
libwhisper.a
2123
libwhisper.so
2224
compile_commands.json
2325

README.md

+1-11
Original file line numberDiff line numberDiff line change
@@ -212,17 +212,7 @@ make large
212212
## Limitations
213213

214214
- Inference only
215-
- No GPU support
216-
- Very basic greedy sampling scheme - always pick up the token with highest probability.
217-
This should be similar to the [GreedyDecoder](https://github.com/openai/whisper/blob/main/whisper/decoding.py#L249-L274)
218-
from the original python implementation, so in order to make a fair comparison between the 2 implementations, make sure
219-
to run the python code with the following parameters:
220-
221-
```
222-
whisper --best_of None --beam_size None ...
223-
```
224-
225-
In the future, `whisper.cpp` will support more sampling strategies.
215+
- No GPU support (yet)
226216

227217
## Another example
228218

examples/command/command.cpp

+66-41
Original file line numberDiff line numberDiff line change
@@ -671,56 +671,81 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const
671671
break;
672672
}
673673

674-
const auto * probs = whisper_get_probs(ctx);
675-
std::vector<std::pair<float, int>> probs_id;
676-
677-
double psum = 0.0;
678-
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
679-
probs_id.emplace_back(probs[allowed_tokens[i][0]], i);
680-
for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) {
681-
probs_id.back().first += probs[allowed_tokens[i][j]];
682-
}
683-
probs_id.back().first /= allowed_tokens[i].size();
684-
psum += probs_id.back().first;
685-
}
674+
// estimate command probability
675+
// NOTE: not optimal
676+
{
677+
const auto * logits = whisper_get_logits(ctx);
686678

687-
// normalize
688-
for (auto & p : probs_id) {
689-
p.first /= psum;
690-
}
679+
std::vector<float> probs(whisper_n_vocab(ctx), 0.0f);
691680

692-
// sort descending
693-
{
694-
using pair_type = decltype(probs_id)::value_type;
695-
std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
696-
return a.first > b.first;
697-
});
698-
}
681+
// compute probs from logits via softmax
682+
{
683+
float max = -1e9;
684+
for (int i = 0; i < (int) probs.size(); ++i) {
685+
max = std::max(max, logits[i]);
686+
}
699687

700-
// print the commands and the respective probabilities
701-
{
702-
fprintf(stdout, "\n");
703-
for (const auto & cmd : probs_id) {
704-
fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first);
705-
for (int token : allowed_tokens[cmd.second]) {
706-
fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, token), probs[token]);
688+
float sum = 0.0f;
689+
for (int i = 0; i < (int) probs.size(); ++i) {
690+
probs[i] = expf(logits[i] - max);
691+
sum += probs[i];
692+
}
693+
694+
for (int i = 0; i < (int) probs.size(); ++i) {
695+
probs[i] /= sum;
707696
}
697+
}
698+
699+
std::vector<std::pair<float, int>> probs_id;
700+
701+
double psum = 0.0;
702+
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
703+
probs_id.emplace_back(probs[allowed_tokens[i][0]], i);
704+
for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) {
705+
probs_id.back().first += probs[allowed_tokens[i][j]];
706+
}
707+
probs_id.back().first /= allowed_tokens[i].size();
708+
psum += probs_id.back().first;
709+
}
710+
711+
// normalize
712+
for (auto & p : probs_id) {
713+
p.first /= psum;
714+
}
715+
716+
// sort descending
717+
{
718+
using pair_type = decltype(probs_id)::value_type;
719+
std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
720+
return a.first > b.first;
721+
});
722+
}
723+
724+
// print the commands and the respective probabilities
725+
{
708726
fprintf(stdout, "\n");
727+
for (const auto & cmd : probs_id) {
728+
fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first);
729+
for (int token : allowed_tokens[cmd.second]) {
730+
fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, token), probs[token]);
731+
}
732+
fprintf(stdout, "\n");
733+
}
709734
}
710-
}
711735

712-
// best command
713-
{
714-
const auto t_end = std::chrono::high_resolution_clock::now();
736+
// best command
737+
{
738+
const auto t_end = std::chrono::high_resolution_clock::now();
715739

716-
const float prob = probs_id[0].first;
717-
const int index = probs_id[0].second;
740+
const float prob = probs_id[0].first;
741+
const int index = probs_id[0].second;
718742

719-
fprintf(stdout, "\n");
720-
fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__,
721-
"\033[1m", allowed_commands[index].c_str(), "\033[0m", prob,
722-
(int) std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count());
723-
fprintf(stdout, "\n");
743+
fprintf(stdout, "\n");
744+
fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__,
745+
"\033[1m", allowed_commands[index].c_str(), "\033[0m", prob,
746+
(int) std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count());
747+
fprintf(stdout, "\n");
748+
}
724749
}
725750

726751
audio.clear();

examples/main/main.cpp

+55-36
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,12 @@ struct whisper_params {
5959
int32_t duration_ms = 0;
6060
int32_t max_context = -1;
6161
int32_t max_len = 0;
62+
int32_t best_of = 5;
63+
int32_t beam_size = -1;
6264

63-
float word_thold = 0.01f;
65+
float word_thold = 0.01f;
66+
float entropy_thold = 2.4f;
67+
float logprob_thold = -1.0f;
6468

6569
bool speed_up = false;
6670
bool translate = false;
@@ -104,7 +108,11 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
104108
else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); }
105109
else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); }
106110
else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
111+
else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); }
112+
else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); }
107113
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
114+
else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
115+
else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
108116
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
109117
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
110118
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
@@ -136,31 +144,35 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
136144
fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]);
137145
fprintf(stderr, "\n");
138146
fprintf(stderr, "options:\n");
139-
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
140-
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
141-
fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors);
142-
fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms);
143-
fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n);
144-
fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
145-
fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
146-
fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
147-
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
148-
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
149-
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
150-
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
151-
fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
152-
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
153-
fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
154-
fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false");
155-
fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false");
156-
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
157-
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
158-
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
159-
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
160-
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
161-
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
162-
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
163-
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
147+
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
148+
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
149+
fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors);
150+
fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms);
151+
fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n);
152+
fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
153+
fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
154+
fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
155+
fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of);
156+
fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size);
157+
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
158+
fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
159+
fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
160+
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
161+
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
162+
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
163+
fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
164+
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
165+
fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
166+
fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false");
167+
fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false");
168+
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
169+
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
170+
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
171+
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
172+
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
173+
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
174+
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
175+
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
164176
fprintf(stderr, "\n");
165177
}
166178

@@ -235,7 +247,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
235247
const char * text = whisper_full_get_token_text(ctx, i, j);
236248
const float p = whisper_full_get_token_p (ctx, i, j);
237249

238-
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
250+
const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(p, 3)*float(k_colors.size()))));
239251

240252
printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
241253
}
@@ -331,20 +343,19 @@ bool output_csv(struct whisper_context * ctx, const char * fname) {
331343
const int n_segments = whisper_full_n_segments(ctx);
332344
for (int i = 0; i < n_segments; ++i) {
333345
const char * text = whisper_full_get_segment_text(ctx, i);
334-
if (text[0] == ' ')
335-
text = text + sizeof(char); //whisper_full_get_segment_text() returns a string with leading space, point to the next character.
346+
if (text[0] == ' ') {
347+
text = text + sizeof(char); //whisper_full_get_segment_text() returns a string with leading space, point to the next character.
348+
}
336349
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
337350
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
338-
//need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds.
339-
fout << 10 * t0 << ", "
340-
<< 10 * t1 << ", \""
341-
<< text << "\"\n";
351+
352+
//need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds.
353+
fout << 10 * t0 << ", " << 10 * t1 << ", \"" << text << "\"\n";
342354
}
343355

344356
return true;
345357
}
346358

347-
348359
// karaoke video generation
349360
// outputs a bash script that uses ffmpeg to generate a video with the subtitles
350361
// TODO: font parameter adjustments
@@ -620,6 +631,8 @@ int main(int argc, char ** argv) {
620631
{
621632
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
622633

634+
wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
635+
623636
wparams.print_realtime = false;
624637
wparams.print_progress = params.print_progress;
625638
wparams.print_timestamps = !params.no_timestamps;
@@ -633,12 +646,18 @@ int main(int argc, char ** argv) {
633646

634647
wparams.token_timestamps = params.output_wts || params.max_len > 0;
635648
wparams.thold_pt = params.word_thold;
649+
wparams.entropy_thold = params.entropy_thold;
650+
wparams.logprob_thold = params.logprob_thold;
636651
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
637652

638653
wparams.speed_up = params.speed_up;
639654

640-
wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
641-
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
655+
wparams.greedy.best_of = params.best_of;
656+
wparams.beam_search.beam_size = params.beam_size;
657+
wparams.temperature_inc = -1;
658+
659+
wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
660+
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
642661

643662
whisper_print_user_data user_data = { &params, &pcmf32s };
644663

examples/stream.wasm/emscripten.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ void stream_main(size_t index) {
4949
wparams.max_tokens = 32;
5050
wparams.audio_ctx = 768; // partial encoder context for better performance
5151

52+
// disable temperature fallback
53+
wparams.temperature_inc = -1.0f;
54+
5255
wparams.language = "en";
5356

5457
printf("stream: using %d threads\n", wparams.n_threads);

examples/stream/stream.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,9 @@ int main(int argc, char ** argv) {
615615
wparams.audio_ctx = params.audio_ctx;
616616
wparams.speed_up = params.speed_up;
617617

618+
// disable temperature fallback
619+
wparams.temperature_inc = -1.0f;
620+
618621
wparams.prompt_tokens = params.no_context ? nullptr : prompt_tokens.data();
619622
wparams.prompt_n_tokens = params.no_context ? 0 : prompt_tokens.size();
620623

examples/whisper.swiftui/whisper.swiftui.demo/Resources/models/.gitignore

Whitespace-only changes.

examples/whisper.swiftui/whisper.swiftui.demo/Resources/samples/.gitignore

Whitespace-only changes.

0 commit comments

Comments
 (0)