25
25
from sglang .srt .entrypoints .http_server import launch_server
26
26
from sglang .srt .server_args import ServerArgs
27
27
from sglang .srt .utils import kill_process_tree
28
+ from sglang .test .test_utils import is_in_ci , write_github_step_summary
28
29
29
30
30
31
@dataclasses .dataclass
@@ -33,9 +34,13 @@ class BenchArgs:
33
34
batch_size : Tuple [int ] = (1 ,)
34
35
input_len : Tuple [int ] = (1024 ,)
35
36
output_len : Tuple [int ] = (16 ,)
37
+ temperature : float = 0.0
38
+ return_logprob : bool = False
39
+ input_len_step_percentage : float = 0.0
36
40
result_filename : str = "result.jsonl"
37
41
base_url : str = ""
38
42
skip_warmup : bool = False
43
+ show_report : bool = False
39
44
40
45
@staticmethod
41
46
def add_cli_args (parser : argparse .ArgumentParser ):
@@ -49,11 +54,19 @@ def add_cli_args(parser: argparse.ArgumentParser):
49
54
parser .add_argument (
50
55
"--output-len" , type = int , nargs = "+" , default = BenchArgs .output_len
51
56
)
57
+ parser .add_argument ("--temperature" , type = float , default = BenchArgs .temperature )
58
+ parser .add_argument ("--return-logprob" , action = "store_true" )
59
+ parser .add_argument (
60
+ "--input-len-step-percentage" ,
61
+ type = float ,
62
+ default = BenchArgs .input_len_step_percentage ,
63
+ )
52
64
parser .add_argument (
53
65
"--result-filename" , type = str , default = BenchArgs .result_filename
54
66
)
55
67
parser .add_argument ("--base-url" , type = str , default = BenchArgs .base_url )
56
68
parser .add_argument ("--skip-warmup" , action = "store_true" )
69
+ parser .add_argument ("--show-report" , action = "store_true" )
57
70
58
71
@classmethod
59
72
def from_cli_args (cls , args : argparse .Namespace ):
@@ -99,36 +112,89 @@ def run_one_case(
99
112
batch_size : int ,
100
113
input_len : int ,
101
114
output_len : int ,
115
+ temperature : float ,
116
+ return_logprob : bool ,
117
+ input_len_step_percentage : float ,
102
118
run_name : str ,
103
119
result_filename : str ,
104
120
):
121
+ requests .post (url + "/flush_cache" )
122
+ input_lens = [
123
+ int (input_len * (1 + (i - (batch_size - 1 ) / 2 ) * input_len_step_percentage ))
124
+ for i in range (batch_size )
125
+ ]
105
126
input_ids = [
106
- [int (x ) for x in np .random .randint (0 , high = 16384 , size = (input_len ,))]
107
- for _ in range (batch_size )
127
+ [int (x ) for x in np .random .randint (0 , high = 16384 , size = (input_lens [ i ] ,))]
128
+ for i in range (batch_size )
108
129
]
109
130
131
+ use_structured_outputs = False
132
+ if use_structured_outputs :
133
+ texts = []
134
+ for _ in range (batch_size ):
135
+ texts .append (
136
+ "Human: What is the capital city of france? can you give as many trivial information as possible about that city? answer in json.\n "
137
+ * 50
138
+ + "Assistant:"
139
+ )
140
+ json_schema = "$$ANY$$"
141
+ else :
142
+ json_schema = None
143
+
110
144
tic = time .time ()
111
145
response = requests .post (
112
146
url + "/generate" ,
113
147
json = {
148
+ # "text": texts,
114
149
"input_ids" : input_ids ,
115
150
"sampling_params" : {
116
- "temperature" : 0 ,
151
+ "temperature" : temperature ,
117
152
"max_new_tokens" : output_len ,
118
153
"ignore_eos" : True ,
154
+ "json_schema" : json_schema ,
119
155
},
156
+ "return_logprob" : return_logprob ,
157
+ "stream" : True ,
120
158
},
159
+ stream = True ,
121
160
)
122
- latency = time .time () - tic
123
161
124
- _ = response .json ()
125
- output_throughput = batch_size * output_len / latency
162
+ # The TTFT of the last request in the batch
163
+ ttft = 0.0
164
+ for chunk in response .iter_lines (decode_unicode = False ):
165
+ chunk = chunk .decode ("utf-8" )
166
+ if chunk and chunk .startswith ("data:" ):
167
+ if chunk == "data: [DONE]" :
168
+ break
169
+ data = json .loads (chunk [5 :].strip ("\n " ))
170
+ if "error" in data :
171
+ raise RuntimeError (f"Request has failed. { data } ." )
172
+
173
+ assert (
174
+ data ["meta_info" ]["finish_reason" ] is None
175
+ or data ["meta_info" ]["finish_reason" ]["type" ] == "length"
176
+ )
177
+ if data ["meta_info" ]["completion_tokens" ] == 1 :
178
+ ttft = time .time () - tic
179
+
180
+ latency = time .time () - tic
181
+ input_throughput = batch_size * input_len / ttft
182
+ output_throughput = batch_size * output_len / (latency - ttft )
126
183
overall_throughput = batch_size * (input_len + output_len ) / latency
127
184
185
+ server_info = requests .get (url + "/get_server_info" ).json ()
186
+ acc_length = server_info ["internal_states" ][0 ].get ("avg_spec_accept_length" , None )
187
+ last_gen_throughput = server_info ["internal_states" ][0 ]["last_gen_throughput" ]
188
+
128
189
print (f"batch size: { batch_size } " )
190
+ print (f"input_len: { input_len } " )
191
+ print (f"output_len: { output_len } " )
129
192
print (f"latency: { latency :.2f} s" )
130
- print (f"output throughput: { output_throughput :.2f} token/s" )
131
- print (f"(input + output) throughput: { overall_throughput :.2f} token/s" )
193
+ print (f"ttft: { ttft :.2f} s" )
194
+ print (f"Last generation throughput: { last_gen_throughput :.2f} tok/s" )
195
+ print (f"Input throughput: { input_throughput :.2f} tok/s" )
196
+ if output_len != 1 :
197
+ print (f"output throughput: { output_throughput :.2f} tok/s" )
132
198
133
199
if result_filename :
134
200
with open (result_filename , "a" ) as fout :
@@ -140,9 +206,21 @@ def run_one_case(
140
206
"latency" : round (latency , 4 ),
141
207
"output_throughput" : round (output_throughput , 2 ),
142
208
"overall_throughput" : round (overall_throughput , 2 ),
209
+ "last_gen_throughput" : round (last_gen_throughput , 2 ),
143
210
}
144
211
fout .write (json .dumps (res ) + "\n " )
145
212
213
+ return (
214
+ batch_size ,
215
+ latency ,
216
+ ttft ,
217
+ input_throughput ,
218
+ output_throughput ,
219
+ overall_throughput ,
220
+ last_gen_throughput ,
221
+ acc_length ,
222
+ )
223
+
146
224
147
225
def run_benchmark (server_args : ServerArgs , bench_args : BenchArgs ):
148
226
if bench_args .base_url :
@@ -152,34 +230,84 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
152
230
153
231
# warmup
154
232
if not bench_args .skip_warmup :
233
+ print ("=" * 8 + " Warmup Begin " + "=" * 8 )
155
234
run_one_case (
156
235
base_url ,
157
236
batch_size = 16 ,
158
237
input_len = 1024 ,
159
238
output_len = 16 ,
239
+ temperature = bench_args .temperature ,
240
+ return_logprob = bench_args .return_logprob ,
241
+ input_len_step_percentage = bench_args .input_len_step_percentage ,
160
242
run_name = "" ,
161
243
result_filename = "" ,
162
244
)
245
+ print ("=" * 8 + " Warmup End " + "=" * 8 + "\n " )
163
246
164
247
# benchmark
248
+ result = []
165
249
try :
166
250
for bs , il , ol in itertools .product (
167
251
bench_args .batch_size , bench_args .input_len , bench_args .output_len
168
252
):
169
- run_one_case (
170
- base_url ,
171
- bs ,
172
- il ,
173
- ol ,
174
- bench_args .run_name ,
175
- bench_args .result_filename ,
253
+ result .append (
254
+ run_one_case (
255
+ base_url ,
256
+ bs ,
257
+ il ,
258
+ ol ,
259
+ temperature = bench_args .temperature ,
260
+ return_logprob = bench_args .return_logprob ,
261
+ input_len_step_percentage = bench_args .input_len_step_percentage ,
262
+ run_name = bench_args .run_name ,
263
+ result_filename = bench_args .result_filename ,
264
+ )
176
265
)
177
266
finally :
178
267
if proc :
179
268
kill_process_tree (proc .pid )
180
269
181
270
print (f"\n Results are saved to { bench_args .result_filename } " )
182
271
272
+ if not bench_args .show_report :
273
+ return
274
+
275
+ summary = " | batch size | latency (s) | input throughput (tok/s) | output throughput (tok/s) | acc length | ITL (ms) | input price ($/1M) | output price ($/1M) |\n "
276
+ summary += "| ---------- | ----------- | ------------------------- | ------------------------- | ---------- | -------- | ------------------ | ------------------- |\n "
277
+
278
+ for (
279
+ batch_size ,
280
+ latency ,
281
+ ttft ,
282
+ input_throughput ,
283
+ output_throughput ,
284
+ overall_throughput ,
285
+ last_gen_throughput ,
286
+ acc_length ,
287
+ ) in result :
288
+ hourly_cost = 2 * server_args .tp_size # $2/hour for one H100
289
+ input_util = 0.7
290
+ accept_length = round (acc_length , 2 ) if acc_length is not None else "n/a"
291
+ line = (
292
+ f"| { batch_size } | "
293
+ f"{ latency :.2f} | "
294
+ f"{ input_throughput :.2f} | "
295
+ f"{ output_throughput :.2f} | "
296
+ f"{ accept_length } | "
297
+ f"{ 1 / (output_throughput / batch_size ) * 1000 :.2f} | "
298
+ f"{ 1e6 / (input_throughput * input_util ) / 3600 * hourly_cost :.2f} | "
299
+ f"{ 1e6 / output_throughput / 3600 * hourly_cost :.2f} |\n "
300
+ )
301
+ summary += line
302
+
303
+ # print metrics table
304
+ print (summary )
305
+
306
+ if is_in_ci ():
307
+ write_github_step_summary (
308
+ f"### Test Nightly Benchmark (bench_one_batch) \n { summary } "
309
+ )
310
+
183
311
184
312
if __name__ == "__main__" :
185
313
parser = argparse .ArgumentParser ()
0 commit comments