Skip to content

Commit ae1bd69

Browse files
committed
bench : add batch size 5 bench
1 parent 3ed9af3 commit ae1bd69

File tree

3 files changed

+37
-16
lines changed

3 files changed

+37
-16
lines changed

examples/bench/bench.cpp

Lines changed: 20 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -81,7 +81,7 @@ int whisper_bench_full(const whisper_params & params) {
8181
}
8282
// heat encoder
8383
if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
84-
fprintf(stderr, "error: failed to encode model: %d\n", ret);
84+
fprintf(stderr, "error: failed to encode: %d\n", ret);
8585
return 4;
8686
}
8787

@@ -90,34 +90,44 @@ int whisper_bench_full(const whisper_params & params) {
9090

9191
// prompt heat
9292
if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
93-
fprintf(stderr, "error: failed to encode model: %d\n", ret);
93+
fprintf(stderr, "error: failed to decode: %d\n", ret);
9494
return 4;
9595
}
9696

9797
// text-generation heat
9898
if (int ret = whisper_decode(ctx, tokens, 1, 256, params.n_threads) != 0) {
99-
fprintf(stderr, "error: failed to encode model: %d\n", ret);
99+
fprintf(stderr, "error: failed to decode: %d\n", ret);
100100
return 4;
101101
}
102102

103103
whisper_reset_timings(ctx);
104104

105105
// actual run
106106
if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
107-
fprintf(stderr, "error: failed to encode model: %d\n", ret);
107+
fprintf(stderr, "error: failed to encode: %d\n", ret);
108108
return 4;
109109
}
110110

111-
for (int i = 0; i < 16; i++) {
112-
if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
113-
fprintf(stderr, "error: failed to encode model: %d\n", ret);
111+
// text-generation
112+
for (int i = 0; i < 256; i++) {
113+
if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) {
114+
fprintf(stderr, "error: failed to decode: %d\n", ret);
114115
return 4;
115116
}
116117
}
117118

118-
for (int i = 0; i < 256; i++) {
119-
if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) {
120-
fprintf(stderr, "error: failed to encode model: %d\n", ret);
119+
// batched decoding
120+
for (int i = 0; i < 64; i++) {
121+
if (int ret = whisper_decode(ctx, tokens, 5, 0, params.n_threads) != 0) {
122+
fprintf(stderr, "error: failed to decode: %d\n", ret);
123+
return 4;
124+
}
125+
}
126+
127+
// prompt processing
128+
for (int i = 0; i < 16; i++) {
129+
if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
130+
fprintf(stderr, "error: failed to decode: %d\n", ret);
121131
return 4;
122132
}
123133
}

extra/bench-all.sh

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -44,8 +44,8 @@ if [ "$encoder_only" -eq 0 ]; then
4444
printf "\n"
4545
fi
4646

47-
printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "PP" "Commit"
48-
printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---"
47+
printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "Bch5" "PP" "Commit"
48+
printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---" "---"
4949

5050
for model in "${models[@]}"; do
5151
# actual run
@@ -56,6 +56,7 @@ for model in "${models[@]}"; do
5656
# parse the output:
5757
encode_time=$(echo "$output" | grep "encode time" | awk '{print $11}')
5858
decode_time=$(echo "$output" | grep "decode time" | awk '{print $11}')
59+
batchd_time=$(echo "$output" | grep "batchd time" | awk '{print $11}')
5960
prompt_time=$(echo "$output" | grep "prompt time" | awk '{print $11}')
6061
system_info=$(echo "$output" | grep "system_info")
6162
n_threads=$(echo "$output" | grep "system_info" | awk '{print $4}')
@@ -94,6 +95,6 @@ for model in "${models[@]}"; do
9495
commit=$(git rev-parse --short HEAD)
9596

9697
if [ $ret -eq 0 ]; then
97-
printf "| <todo> | <todo> | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$prompt_time" "$commit"
98+
printf "| <todo> | <todo> | %16s | %11s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$batchd_time" "$prompt_time" "$commit"
9899
fi
99100
done

whisper.cpp

Lines changed: 13 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -773,13 +773,15 @@ struct whisper_state {
773773
int64_t t_sample_us = 0;
774774
int64_t t_encode_us = 0;
775775
int64_t t_decode_us = 0;
776+
int64_t t_batchd_us = 0;
776777
int64_t t_prompt_us = 0;
777778
int64_t t_mel_us = 0;
778779

779780
int32_t n_sample = 0; // number of tokens sampled
780781
int32_t n_encode = 0; // number of encoder calls
781-
int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation)
782-
int32_t n_prompt = 0; // number of decoder calls with n_tokens > 1 (prompt encoding)
782+
int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation)
783+
int32_t n_batchd = 0; // number of decoder calls with n_tokens < 16 (batch decoding)
784+
int32_t n_prompt = 0; // number of decoder calls with n_tokens > 1 (prompt encoding)
783785
int32_t n_fail_p = 0; // number of logprob threshold failures
784786
int32_t n_fail_h = 0; // number of entropy threshold failures
785787

@@ -2616,9 +2618,12 @@ static bool whisper_decode_internal(
26162618
if (batch.n_tokens == 1) {
26172619
wstate.t_decode_us += ggml_time_us() - t_start_us;
26182620
wstate.n_decode++;
2621+
} else if (batch.n_tokens < 16) {
2622+
wstate.t_batchd_us += ggml_time_us() - t_start_us;
2623+
wstate.n_batchd += n_tokens;
26192624
} else {
26202625
wstate.t_prompt_us += ggml_time_us() - t_start_us;
2621-
wstate.n_prompt++;
2626+
wstate.n_prompt += n_tokens;
26222627
}
26232628

26242629
return !(abort_callback && abort_callback(abort_callback_data));
@@ -3827,13 +3832,15 @@ void whisper_print_timings(struct whisper_context * ctx) {
38273832
const int32_t n_sample = std::max(1, ctx->state->n_sample);
38283833
const int32_t n_encode = std::max(1, ctx->state->n_encode);
38293834
const int32_t n_decode = std::max(1, ctx->state->n_decode);
3835+
const int32_t n_batchd = std::max(1, ctx->state->n_batchd);
38303836
const int32_t n_prompt = std::max(1, ctx->state->n_prompt);
38313837

38323838
WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
38333839
WHISPER_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
38343840
WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
38353841
WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
38363842
WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
3843+
WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_batchd_us, n_batchd, 1e-3f * ctx->state->t_batchd_us / n_batchd);
38373844
WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
38383845
}
38393846
WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
@@ -3850,6 +3857,7 @@ void whisper_reset_timings(struct whisper_context * ctx) {
38503857
ctx->state->n_sample = 0;
38513858
ctx->state->n_encode = 0;
38523859
ctx->state->n_decode = 0;
3860+
ctx->state->n_batchd = 0;
38533861
ctx->state->n_prompt = 0;
38543862
}
38553863
}
@@ -5896,11 +5904,13 @@ int whisper_full_parallel(
58965904
ctx->state->t_sample_us += states[i]->t_sample_us;
58975905
ctx->state->t_encode_us += states[i]->t_encode_us;
58985906
ctx->state->t_decode_us += states[i]->t_decode_us;
5907+
ctx->state->t_batchd_us += states[i]->t_batchd_us;
58995908
ctx->state->t_prompt_us += states[i]->t_prompt_us;
59005909

59015910
ctx->state->n_sample += states[i]->n_sample;
59025911
ctx->state->n_encode += states[i]->n_encode;
59035912
ctx->state->n_decode += states[i]->n_decode;
5913+
ctx->state->n_batchd += states[i]->n_batchd;
59045914
ctx->state->n_prompt += states[i]->n_prompt;
59055915

59065916
whisper_free_state(states[i]);

0 commit comments

Comments
 (0)