Skip to content

whisper : add batched decoding #1486

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions examples/bench/bench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ int whisper_bench_full(const whisper_params & params) {
}
// heat encoder
if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
fprintf(stderr, "error: failed to encode model: %d\n", ret);
fprintf(stderr, "error: failed to encode: %d\n", ret);
return 4;
}

Expand All @@ -90,34 +90,44 @@ int whisper_bench_full(const whisper_params & params) {

// prompt heat
if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
fprintf(stderr, "error: failed to encode model: %d\n", ret);
fprintf(stderr, "error: failed to decode: %d\n", ret);
return 4;
}

// text-generation heat
if (int ret = whisper_decode(ctx, tokens, 1, 256, params.n_threads) != 0) {
fprintf(stderr, "error: failed to encode model: %d\n", ret);
fprintf(stderr, "error: failed to decode: %d\n", ret);
return 4;
}

whisper_reset_timings(ctx);

// actual run
if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
fprintf(stderr, "error: failed to encode model: %d\n", ret);
fprintf(stderr, "error: failed to encode: %d\n", ret);
return 4;
}

for (int i = 0; i < 16; i++) {
if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
fprintf(stderr, "error: failed to encode model: %d\n", ret);
// text-generation
for (int i = 0; i < 256; i++) {
if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) {
fprintf(stderr, "error: failed to decode: %d\n", ret);
return 4;
}
}

for (int i = 0; i < 256; i++) {
if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) {
fprintf(stderr, "error: failed to encode model: %d\n", ret);
// batched decoding
for (int i = 0; i < 64; i++) {
if (int ret = whisper_decode(ctx, tokens, 5, 0, params.n_threads) != 0) {
fprintf(stderr, "error: failed to decode: %d\n", ret);
return 4;
}
}

// prompt processing
for (int i = 0; i < 16; i++) {
if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
fprintf(stderr, "error: failed to decode: %d\n", ret);
return 4;
}
}
Expand Down
8 changes: 4 additions & 4 deletions examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ struct whisper_params {
int32_t progress_step = 5;
int32_t max_context = -1;
int32_t max_len = 0;
int32_t best_of = 2;
int32_t beam_size = -1;
int32_t best_of = whisper_full_default_params(WHISPER_SAMPLING_GREEDY).greedy.best_of;
int32_t beam_size = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size;

float word_thold = 0.01f;
float entropy_thold = 2.40f;
Expand Down Expand Up @@ -925,9 +925,9 @@ int main(int argc, char ** argv) {
if (params.detect_language) {
params.language = "auto";
}
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, %stimestamps = %d ...\n",
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, %d beams + best of %d, lang = %s, task = %s, %stimestamps = %d ...\n",
__func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
params.n_threads, params.n_processors,
params.n_threads, params.n_processors, params.beam_size, params.best_of,
params.language.c_str(),
params.translate ? "translate" : "transcribe",
params.tinydiarize ? "tdrz = 1, " : "",
Expand Down
7 changes: 4 additions & 3 deletions extra/bench-all.sh
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ if [ "$encoder_only" -eq 0 ]; then
printf "\n"
fi

printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "PP" "Commit"
printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---"
printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "Bch5" "PP" "Commit"
printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---" "---"

for model in "${models[@]}"; do
# actual run
Expand All @@ -56,6 +56,7 @@ for model in "${models[@]}"; do
# parse the output:
encode_time=$(echo "$output" | grep "encode time" | awk '{print $11}')
decode_time=$(echo "$output" | grep "decode time" | awk '{print $11}')
batchd_time=$(echo "$output" | grep "batchd time" | awk '{print $11}')
prompt_time=$(echo "$output" | grep "prompt time" | awk '{print $11}')
system_info=$(echo "$output" | grep "system_info")
n_threads=$(echo "$output" | grep "system_info" | awk '{print $4}')
Expand Down Expand Up @@ -94,6 +95,6 @@ for model in "${models[@]}"; do
commit=$(git rev-parse --short HEAD)

if [ $ret -eq 0 ]; then
printf "| <todo> | <todo> | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$prompt_time" "$commit"
printf "| <todo> | <todo> | %16s | %11s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$batchd_time" "$prompt_time" "$commit"
fi
done
Loading