Commit 64ed209

lhpqaq and ngxson authored
server: Add "tokens per second" information in the backend (ggml-org#10548)
* add cmake rvv support
* add timings
* remove space
* update readme
* fix
* fix code
* remove empty line
* add test

Co-authored-by: Xuan Son Nguyen <[email protected]>
1 parent 991f8aa commit 64ed209

File tree (5 files changed: +44 -1 lines changed)

  common/common.h
  examples/server/README.md
  examples/server/server.cpp
  examples/server/tests/unit/test_chat_completion.py
  examples/server/utils.hpp

common/common.h

Lines changed: 1 addition & 0 deletions

@@ -133,6 +133,7 @@ struct common_params_sampling {
     bool penalize_nl = false; // consider newlines as a repeatable token
     bool ignore_eos = false;
     bool no_perf = false; // disable performance metrics
+    bool timing_per_token = false;
 
     std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY

examples/server/README.md

Lines changed: 2 additions & 0 deletions

@@ -416,6 +416,8 @@ node index.js
 
 `samplers`: The order the samplers should be applied in. An array of strings representing sampler type names. If a sampler is not set, it will not be used. If a sampler is specified more than once, it will be applied multiple times. Default: `["dry", "top_k", "typ_p", "top_p", "min_p", "xtc", "temperature"]` - these are all the available values.
 
+`timings_per_token`: Include prompt processing and text generation speed information in each response. Default: `false`
+
 **Response format**
 
 - Note: When using streaming mode (`stream`), only `content` and `stop` will be returned until end of completion.
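For context, a minimal client-side sketch of how the new `timings_per_token` parameter might be exercised against a locally running llama.cpp server. The host/port, the `requests` dependency, and the SSE framing details are assumptions; the endpoint path and the `timings` field names follow the test added in this commit:

    # Sketch: request a streamed chat completion with per-token timings enabled.
    # Assumes a llama.cpp server is listening on localhost:8080.
    import json
    import requests

    resp = requests.post(
        "http://localhost:8080/chat/completions",
        json={
            "messages": [{"role": "user", "content": "test"}],
            "max_tokens": 10,
            "stream": True,
            "timings_per_token": True,
        },
        stream=True,
    )

    for line in resp.iter_lines():
        # Server-sent events: each payload line is prefixed with "data: ".
        if not line or not line.startswith(b"data: "):
            continue
        payload = line[len(b"data: "):]
        if payload == b"[DONE]":  # OpenAI-style end-of-stream sentinel
            break
        chunk = json.loads(payload)
        # With timings_per_token enabled, each chunk should carry a "timings"
        # object, e.g. prompt_per_second / predicted_per_second / predicted_n.
        print(chunk.get("timings"))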

examples/server/server.cpp

Lines changed: 15 additions & 1 deletion

@@ -177,6 +177,8 @@ struct server_slot {
     bool stopped_word = false;
     bool stopped_limit = false;
 
+    bool timings_per_token = false;
+
     bool oaicompat = false;
 
     std::string oaicompat_model;

@@ -882,6 +884,8 @@ struct server_context {
             slot.oaicompat_model = "";
         }
 
+        slot.timings_per_token = json_value(data, "timings_per_token", false);
+
         slot.params.stream = json_value(data, "stream", false);
         slot.params.cache_prompt = json_value(data, "cache_prompt", true);
         slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));

@@ -1279,6 +1283,7 @@ struct server_context {
             {"speculative.n_max", slot.params.speculative.n_max},
             {"speculative.n_min", slot.params.speculative.n_min},
             {"speculative.p_min", slot.params.speculative.p_min},
+            {"timings_per_token", slot.timings_per_token},
         };
     }

@@ -1336,6 +1341,10 @@ struct server_context {
             res.data["model"] = slot.oaicompat_model;
         }
 
+        if (slot.timings_per_token) {
+            res.data["timings"] = slot.get_formated_timings();
+        }
+
         queue_results.send(res);
     }

@@ -2274,12 +2283,17 @@ struct server_context {
         common_sampler_accept(slot.smpl, id, true);
 
         slot.n_decoded += 1;
+
+        const int64_t t_current = ggml_time_us();
+
         if (slot.n_decoded == 1) {
-            slot.t_start_generation = ggml_time_us();
+            slot.t_start_generation = t_current;
             slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
             metrics.on_prompt_eval(slot);
         }
 
+        slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3;
+
         completion_token_output result;
         result.tok = id;
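The diff above stores `t_prompt_processing` and `t_token_generation` in milliseconds (the `ggml_time_us()` deltas are divided by 1e3). A small sketch of how per-second figures such as `prompt_per_second` and `predicted_per_second` can be derived from those millisecond totals; the helper name and rounding are assumptions, not the server's actual `get_formated_timings()` implementation:

    # Hypothetical helper mirroring the arithmetic implied by the diff:
    # timings are accumulated in milliseconds, throughput is tokens per second.
    def per_second(n_tokens: int, t_ms: float) -> float:
        """Convert a token count plus elapsed milliseconds into tokens/second."""
        return 1e3 * n_tokens / t_ms if t_ms > 0 else 0.0

    # Example: 7 prompt tokens processed in 35 ms  -> 200 tokens/s,
    #          10 tokens generated in 250 ms       -> 40 tokens/s.
    prompt_per_second    = per_second(7, 35.0)     # 200.0
    predicted_per_second = per_second(10, 250.0)   # 40.0
    print(prompt_per_second, predicted_per_second)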

examples/server/tests/unit/test_chat_completion.py

Lines changed: 17 additions & 0 deletions

@@ -146,3 +146,20 @@ def test_invalid_chat_completion_req(messages):
     })
     assert res.status_code == 400 or res.status_code == 500
     assert "error" in res.body
+
+
+def test_chat_completion_with_timings_per_token():
+    global server
+    server.start()
+    res = server.make_stream_request("POST", "/chat/completions", data={
+        "max_tokens": 10,
+        "messages": [{"role": "user", "content": "test"}],
+        "stream": True,
+        "timings_per_token": True,
+    })
+    for data in res:
+        assert "timings" in data
+        assert "prompt_per_second" in data["timings"]
+        assert "predicted_per_second" in data["timings"]
+        assert "predicted_n" in data["timings"]
+        assert data["timings"]["predicted_n"] <= 10

examples/server/utils.hpp

Lines changed: 9 additions & 0 deletions

@@ -650,6 +650,10 @@ static json format_final_response_oaicompat(const json & request, const json & r
         res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
     }
 
+    if (result.contains("timings")) {
+        res.push_back({"timings", json_value(result, "timings", json::object())});
+    }
+
     return res;
 }

@@ -740,6 +744,11 @@ static std::vector<json> format_partial_response_oaicompat(const json & result,
         {"model", modelname},
         {"object", "chat.completion.chunk"}
     };
+
+    if (result.contains("timings")) {
+        ret.push_back({"timings", json_value(result, "timings", json::object())});
+    }
+
     if (!finish_reason.empty()) {
         int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
         int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
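To illustrate what the utils.hpp change produces, here is a sketch (written as a Python literal) of one OpenAI-compatible streaming chunk after `timings` has been attached. Only `model`, `object`, and the `timings` key are taken from the diff and the new test; the remaining fields follow the usual `chat.completion.chunk` shape, the numeric values are invented, and the `timings` keys shown are not necessarily the full set the server emits:

    # Illustrative chunk shape under the assumptions stated above.
    example_chunk = {
        "choices": [
            {"index": 0, "delta": {"content": " world"}, "finish_reason": None}
        ],
        "model": "some-model-name",          # placeholder echoed model name
        "object": "chat.completion.chunk",
        "timings": {                          # added by this commit when enabled
            "prompt_per_second": 182.4,
            "predicted_per_second": 41.7,
            "predicted_n": 3,
        },
    }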
