
Commit 74adbed

merrymercy authored and rkooo567 committed
Log if cuda graph is used & extend cuda graph capture to cuda-graph-max-bs (sgl-project#6201)
Co-authored-by: SangBin Cho <[email protected]>
1 parent 94682fb commit 74adbed

27 files changed (+292 additions, -120 deletions)

python/sglang/bench_offline_throughput.py

Lines changed: 3 additions & 1 deletion
@@ -259,7 +259,9 @@ def throughput_test_once(
         measurement_results["total_input_tokens"]
         + measurement_results["total_output_tokens"]
     ) / latency
-    measurement_results["last_gen_throughput"] = server_info["last_gen_throughput"]
+    measurement_results["last_gen_throughput"] = server_info["internal_states"][0][
+        "last_gen_throughput"
+    ]

     return measurement_results

python/sglang/bench_one_batch.py

Lines changed: 2 additions & 2 deletions
@@ -246,7 +246,7 @@ def extend(reqs, model_runner):
     _maybe_prepare_dp_attn_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
-    logits_output = model_runner.forward(forward_batch)
+    logits_output, _ = model_runner.forward(forward_batch)
     next_token_ids = model_runner.sample(logits_output, forward_batch)
     return next_token_ids, logits_output.next_token_logits, batch

@@ -258,7 +258,7 @@ def decode(input_token_ids, batch, model_runner):
     _maybe_prepare_dp_attn_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
-    logits_output = model_runner.forward(forward_batch)
+    logits_output, _ = model_runner.forward(forward_batch)
     next_token_ids = model_runner.sample(logits_output, forward_batch)
     return next_token_ids, logits_output.next_token_logits
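Note (not part of the diff): model_runner.forward now returns a pair rather than a bare logits object, and these callers discard the second element. Per the commit title, that element presumably reports whether the CUDA-graph path was used. A hypothetical call pattern, with `used_cuda_graph` as an invented name:

    # Sketch only; the flag name is invented, not taken from the commit.
    logits_output, used_cuda_graph = model_runner.forward(forward_batch)
    if used_cuda_graph:
        print("this batch ran inside a captured CUDA graph")
    next_token_ids = model_runner.sample(logits_output, forward_batch)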

python/sglang/bench_one_batch_server.py

Lines changed: 143 additions & 15 deletions
@@ -25,6 +25,7 @@
 from sglang.srt.entrypoints.http_server import launch_server
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import kill_process_tree
+from sglang.test.test_utils import is_in_ci, write_github_step_summary


 @dataclasses.dataclass
@@ -33,9 +34,13 @@ class BenchArgs:
     batch_size: Tuple[int] = (1,)
     input_len: Tuple[int] = (1024,)
     output_len: Tuple[int] = (16,)
+    temperature: float = 0.0
+    return_logprob: bool = False
+    input_len_step_percentage: float = 0.0
     result_filename: str = "result.jsonl"
     base_url: str = ""
     skip_warmup: bool = False
+    show_report: bool = False

     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -49,11 +54,19 @@ def add_cli_args(parser: argparse.ArgumentParser):
         parser.add_argument(
             "--output-len", type=int, nargs="+", default=BenchArgs.output_len
         )
+        parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
+        parser.add_argument("--return-logprob", action="store_true")
+        parser.add_argument(
+            "--input-len-step-percentage",
+            type=float,
+            default=BenchArgs.input_len_step_percentage,
+        )
         parser.add_argument(
             "--result-filename", type=str, default=BenchArgs.result_filename
         )
         parser.add_argument("--base-url", type=str, default=BenchArgs.base_url)
        parser.add_argument("--skip-warmup", action="store_true")
+        parser.add_argument("--show-report", action="store_true")

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -99,36 +112,89 @@ def run_one_case(
     batch_size: int,
     input_len: int,
     output_len: int,
+    temperature: float,
+    return_logprob: bool,
+    input_len_step_percentage: float,
     run_name: str,
     result_filename: str,
 ):
+    requests.post(url + "/flush_cache")
+    input_lens = [
+        int(input_len * (1 + (i - (batch_size - 1) / 2) * input_len_step_percentage))
+        for i in range(batch_size)
+    ]
     input_ids = [
-        [int(x) for x in np.random.randint(0, high=16384, size=(input_len,))]
-        for _ in range(batch_size)
+        [int(x) for x in np.random.randint(0, high=16384, size=(input_lens[i],))]
+        for i in range(batch_size)
     ]

+    use_structured_outputs = False
+    if use_structured_outputs:
+        texts = []
+        for _ in range(batch_size):
+            texts.append(
+                "Human: What is the capital city of france? can you give as many trivial information as possible about that city? answer in json.\n"
+                * 50
+                + "Assistant:"
+            )
+        json_schema = "$$ANY$$"
+    else:
+        json_schema = None
+
     tic = time.time()
     response = requests.post(
         url + "/generate",
         json={
+            # "text": texts,
             "input_ids": input_ids,
             "sampling_params": {
-                "temperature": 0,
+                "temperature": temperature,
                 "max_new_tokens": output_len,
                 "ignore_eos": True,
+                "json_schema": json_schema,
             },
+            "return_logprob": return_logprob,
+            "stream": True,
         },
+        stream=True,
     )
-    latency = time.time() - tic

-    _ = response.json()
-    output_throughput = batch_size * output_len / latency
+    # The TTFT of the last request in the batch
+    ttft = 0.0
+    for chunk in response.iter_lines(decode_unicode=False):
+        chunk = chunk.decode("utf-8")
+        if chunk and chunk.startswith("data:"):
+            if chunk == "data: [DONE]":
+                break
+            data = json.loads(chunk[5:].strip("\n"))
+            if "error" in data:
+                raise RuntimeError(f"Request has failed. {data}.")
+
+            assert (
+                data["meta_info"]["finish_reason"] is None
+                or data["meta_info"]["finish_reason"]["type"] == "length"
+            )
+            if data["meta_info"]["completion_tokens"] == 1:
+                ttft = time.time() - tic
+
+    latency = time.time() - tic
+    input_throughput = batch_size * input_len / ttft
+    output_throughput = batch_size * output_len / (latency - ttft)
     overall_throughput = batch_size * (input_len + output_len) / latency

+    server_info = requests.get(url + "/get_server_info").json()
+    acc_length = server_info["internal_states"][0].get("avg_spec_accept_length", None)
+    last_gen_throughput = server_info["internal_states"][0]["last_gen_throughput"]
+
     print(f"batch size: {batch_size}")
+    print(f"input_len: {input_len}")
+    print(f"output_len: {output_len}")
     print(f"latency: {latency:.2f} s")
-    print(f"output throughput: {output_throughput:.2f} token/s")
-    print(f"(input + output) throughput: {overall_throughput:.2f} token/s")
+    print(f"ttft: {ttft:.2f} s")
+    print(f"Last generation throughput: {last_gen_throughput:.2f} tok/s")
+    print(f"Input throughput: {input_throughput:.2f} tok/s")
+    if output_len != 1:
+        print(f"output throughput: {output_throughput:.2f} tok/s")

     if result_filename:
         with open(result_filename, "a") as fout:
@@ -140,9 +206,21 @@ def run_one_case(
                 "latency": round(latency, 4),
                 "output_throughput": round(output_throughput, 2),
                 "overall_throughput": round(overall_throughput, 2),
+                "last_gen_throughput": round(last_gen_throughput, 2),
             }
             fout.write(json.dumps(res) + "\n")

+    return (
+        batch_size,
+        latency,
+        ttft,
+        input_throughput,
+        output_throughput,
+        overall_throughput,
+        last_gen_throughput,
+        acc_length,
+    )
+

 def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
     if bench_args.base_url:
@@ -152,34 +230,84 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):

     # warmup
     if not bench_args.skip_warmup:
+        print("=" * 8 + " Warmup Begin " + "=" * 8)
         run_one_case(
             base_url,
             batch_size=16,
             input_len=1024,
             output_len=16,
+            temperature=bench_args.temperature,
+            return_logprob=bench_args.return_logprob,
+            input_len_step_percentage=bench_args.input_len_step_percentage,
             run_name="",
             result_filename="",
         )
+        print("=" * 8 + " Warmup End " + "=" * 8 + "\n")

     # benchmark
+    result = []
     try:
         for bs, il, ol in itertools.product(
             bench_args.batch_size, bench_args.input_len, bench_args.output_len
         ):
-            run_one_case(
-                base_url,
-                bs,
-                il,
-                ol,
-                bench_args.run_name,
-                bench_args.result_filename,
+            result.append(
+                run_one_case(
+                    base_url,
+                    bs,
+                    il,
+                    ol,
+                    temperature=bench_args.temperature,
+                    return_logprob=bench_args.return_logprob,
+                    input_len_step_percentage=bench_args.input_len_step_percentage,
+                    run_name=bench_args.run_name,
+                    result_filename=bench_args.result_filename,
+                )
             )
     finally:
         if proc:
             kill_process_tree(proc.pid)

     print(f"\nResults are saved to {bench_args.result_filename}")

+    if not bench_args.show_report:
+        return
+
+    summary = " | batch size | latency (s) | input throughput (tok/s) | output throughput (tok/s) | acc length | ITL (ms) | input price ($/1M) | output price ($/1M) |\n"
+    summary += "| ---------- | ----------- | ------------------------- | ------------------------- | ---------- | -------- | ------------------ | ------------------- |\n"
+
+    for (
+        batch_size,
+        latency,
+        ttft,
+        input_throughput,
+        output_throughput,
+        overall_throughput,
+        last_gen_throughput,
+        acc_length,
+    ) in result:
+        hourly_cost = 2 * server_args.tp_size  # $2/hour for one H100
+        input_util = 0.7
+        accept_length = round(acc_length, 2) if acc_length is not None else "n/a"
+        line = (
+            f"| {batch_size} | "
+            f"{latency:.2f} | "
+            f"{input_throughput:.2f} | "
+            f"{output_throughput:.2f} | "
+            f"{accept_length} | "
+            f"{1 / (output_throughput/batch_size) * 1000:.2f} | "
+            f"{1e6 / (input_throughput * input_util) / 3600 * hourly_cost:.2f} | "
+            f"{1e6 / output_throughput / 3600 * hourly_cost:.2f} |\n"
+        )
+        summary += line
+
+    # print metrics table
+    print(summary)
+
+    if is_in_ci():
+        write_github_step_summary(
+            f"### Test Nightly Benchmark (bench_one_batch) \n{summary}"
+        )
+

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
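The new report columns are derived directly from the measured throughputs. A small worked example with hypothetical numbers (only the $2/hour-per-H100 cost and the 0.7 input-utilization factor come from the diff; the throughput values are made up):

    # Hypothetical inputs, chosen only to illustrate the formulas in the summary table.
    batch_size = 16
    input_throughput = 8000.0   # tok/s, made up
    output_throughput = 1600.0  # tok/s, made up
    tp_size = 1

    hourly_cost = 2 * tp_size   # $2/hour for one H100, as in the diff
    input_util = 0.7
    itl_ms = 1 / (output_throughput / batch_size) * 1000                        # 10.00 ms between tokens per request
    input_price = 1e6 / (input_throughput * input_util) / 3600 * hourly_cost    # ~$0.10 per 1M input tokens
    output_price = 1e6 / output_throughput / 3600 * hourly_cost                 # ~$0.35 per 1M output tokens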

python/sglang/bench_serving.py

Lines changed: 7 additions & 5 deletions
@@ -1103,7 +1103,7 @@ async def benchmark(
     lora_names: List[str],
     extra_request_body: Dict[str, Any],
     profile: bool,
-    pd_seperated: bool = False,
+    pd_separated: bool = False,
     flush_cache: bool = False,
     warmup_requests: int = 1,
 ):

@@ -1239,12 +1239,14 @@ async def limited_request_func(request_func_input, pbar):

     if "sglang" in backend:
         server_info = requests.get(base_url + "/get_server_info")
-        if pd_seperated:
-            accept_length = server_info.json()["decode"][0].get(
+        if pd_separated:
+            accept_length = server_info.json()["decode"][0]["internal_states"][0].get(
                 "avg_spec_accept_length", None
             )
         else:
-            accept_length = server_info.json().get("avg_spec_accept_length", None)
+            accept_length = server_info.json()["internal_states"][0].get(
+                "avg_spec_accept_length", None
+            )
     else:
         accept_length = None

@@ -1541,7 +1543,7 @@ def run_benchmark(args_: argparse.Namespace):
             lora_names=args.lora_name,
             extra_request_body=extra_request_body,
             profile=args.profile,
-            pd_seperated=args.pd_seperated,
+            pd_separated=args.pd_separated,
             flush_cache=args.flush_cache,
         )
     )

python/sglang/srt/constrained/base_grammar_backend.py

Lines changed: 6 additions & 0 deletions
@@ -37,6 +37,12 @@ def accept_token(self, token: int) -> None:
         """
         raise NotImplementedError()

+    def rollback(self, k: int):
+        raise NotImplementedError()
+
+    def is_terminated(self):
+        raise NotImplementedError()
+
     def allocate_vocab_mask(
         self, vocab_size: int, batch_size: int, device
     ) -> torch.Tensor:
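The base class only declares the two new hooks and leaves them unimplemented; concrete grammar backends are expected to override them. As a purely hypothetical illustration of one plausible reading of the semantics (rollback undoes the last k accepted tokens, is_terminated reports whether the grammar reached a final state), a toy backend that just tracks its accepted-token history might look like the sketch below; none of these names come from the commit:

    class ToyGrammar:
        """Illustrative only; real backends delegate to their matcher state."""

        def __init__(self, stop_token: int):
            self.stop_token = stop_token
            self.history = []

        def accept_token(self, token: int) -> None:
            self.history.append(token)

        def rollback(self, k: int) -> None:
            # Drop the last k accepted tokens, e.g. after rejected speculative drafts.
            if k > 0:
                del self.history[-k:]

        def is_terminated(self) -> bool:
            return bool(self.history) and self.history[-1] == self.stop_token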

python/sglang/srt/disaggregation/prefill.py

Lines changed: 1 addition & 3 deletions
@@ -277,19 +277,17 @@ def process_batch_result_disagg_prefill(
             next_token_ids,
             extend_input_len_per_req,
             extend_logprob_start_len_per_req,
-            bid,
         ) = (
             result.logits_output,
             result.next_token_ids,
             result.extend_input_len_per_req,
             result.extend_logprob_start_len_per_req,
-            result.bid,
         )

         # Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
         if self.enable_overlap:
             # wait
-            _, next_token_ids = self.tp_worker.resolve_last_batch_result(launch_done)
+            _, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(launch_done)
         else:
             next_token_ids = result.next_token_ids.tolist()

python/sglang/srt/entrypoints/engine.py

Lines changed: 1 addition & 1 deletion
@@ -330,7 +330,7 @@ def get_server_info(self):
         return {
             **dataclasses.asdict(self.tokenizer_manager.server_args),
             **self.scheduler_info,
-            **internal_states,
+            "internal_states": internal_states,
             "version": __version__,
         }

python/sglang/srt/entrypoints/http_server.py

Lines changed: 1 addition & 1 deletion
@@ -222,7 +222,7 @@ async def get_server_info():
     return {
         **dataclasses.asdict(_global_state.tokenizer_manager.server_args),
         **_global_state.scheduler_info,
-        **internal_states,
+        "internal_states": internal_states,
         "version": __version__,
     }
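With this change (here and in engine.py above), the scheduler's runtime metrics are nested under an "internal_states" list instead of being merged into the top-level response, which is why the benchmark scripts now index ["internal_states"][0]. A minimal client-side sketch, assuming a locally running server on the default port; avg_spec_accept_length is only present when speculative decoding is enabled:

    import requests

    info = requests.get("http://localhost:30000/get_server_info").json()
    state = info["internal_states"][0]          # first reported internal state
    print(state["last_gen_throughput"])         # tok/s of the most recent generation batch
    print(state.get("avg_spec_accept_length"))  # None unless speculative decoding is on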

python/sglang/srt/layers/attention/utils.py

Lines changed: 4 additions & 2 deletions
@@ -28,7 +28,8 @@ def create_flashinfer_kv_indices_triton(

     num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
     for i in range(num_loop):
-        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        # index into req_to_token_ptr needs to be int64
+        offset = tl.arange(0, BLOCK_SIZE).to(tl.int64) + i * BLOCK_SIZE
         mask = offset < kv_end - kv_start
         data = tl.load(
             req_to_token_ptr

@@ -70,8 +71,9 @@ def create_flashmla_kv_indices_triton(
     num_pages_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)

     for i in range(num_pages_loop):
+        # index into req_to_token_ptr needs to be int64
         paged_offset = (
-            tl.arange(0, NUM_PAGE_PER_BLOCK) + i * NUM_PAGE_PER_BLOCK
+            tl.arange(0, NUM_PAGE_PER_BLOCK).to(tl.int64) + i * NUM_PAGE_PER_BLOCK
         ) * PAGED_SIZE
         paged_offset_out = tl.arange(0, NUM_PAGE_PER_BLOCK) + i * NUM_PAGE_PER_BLOCK
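Both fixes cast the block-local offsets to int64 before they enter the pointer arithmetic, so indexing into req_to_token stays correct once the flattened pool exceeds the int32 range. A standalone sketch of the same pattern, independent of the SGLang kernels (requires triton and a CUDA device):

    import torch
    import triton
    import triton.language as tl


    @triton.jit
    def copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        pid = tl.program_id(0)
        # Cast to int64 first; otherwise the offsets stay int32 and wrap past 2**31 - 1.
        offsets = tl.arange(0, BLOCK_SIZE).to(tl.int64) + pid * BLOCK_SIZE
        mask = offsets < n_elements
        vals = tl.load(src_ptr + offsets, mask=mask)
        tl.store(dst_ptr + offsets, vals, mask=mask)


    # Usage:
    # x = torch.arange(1 << 20, device="cuda", dtype=torch.int32)
    # y = torch.empty_like(x)
    # copy_kernel[(triton.cdiv(x.numel(), 1024),)](x, y, x.numel(), BLOCK_SIZE=1024)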
