    DynamicInferenceContext,
)
from megatron.core.inference.engines import DynamicInferenceEngine
-from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import GPTInferenceWrapper
+from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
+    GPTInferenceWrapper,
+)
from megatron.core.inference.sampling_params import SamplingParams
-from megatron.core.inference.text_generation_controllers.text_generation_controller import TextGenerationController
-from megatron.core.transformer.module import MegatronModule
-from megatron.training import (
-    get_args,
-    get_model as _get_model,
-    get_tokenizer,
-    initialize_megatron,
+from megatron.core.inference.text_generation_controllers.text_generation_controller import (
+    TextGenerationController,
)
+from megatron.core.transformer.module import MegatronModule
+from megatron.training import get_args, get_model as _get_model, get_tokenizer, initialize_megatron
from megatron.training.checkpointing import load_checkpoint
from pretrain_gpt import model_provider
@@ -33,9 +32,11 @@ def add_dynamic_inference_args(parser: ArgumentParser) -> ArgumentParser:
    add_common_inference_args(parser)

    group = parser.add_argument_group(title='Dynamic inference')
-    group.add_argument("--inference-ckpt-non-strict", action="store_true",
-                       help="Load checkpoint with `strict=False`.")
-
+    group.add_argument(
+        "--inference-ckpt-non-strict",
+        action="store_true",
+        help="Load checkpoint with `strict=False`.",
+    )

    return parser
@@ -68,10 +69,7 @@ def get_model() -> MegatronModule:
    return model


-def get_inference_context(
-    requests: List[Request],
-    sampling_params: SamplingParams,
-):
+def get_inference_context(requests: List[Request], sampling_params: SamplingParams):
    """The inference context manages the KV cache and other inference state."""

    args = get_args()
@@ -86,7 +84,9 @@ def get_inference_context(
        params_dtype=args.params_dtype,
        num_layers=args.num_layers,
        kv_channels=args.kv_channels,
-        num_attention_heads=args.num_query_groups if args.group_query_attention else args.num_attention_heads,
+        num_attention_heads=(
+            args.num_query_groups if args.group_query_attention else args.num_attention_heads
+        ),
        max_sequence_length=max_sequence_length,
        buffer_size_gb=args.inference_dynamic_batching_buffer_size_gb,
        buffer_guaranteed_fraction=args.inference_dynamic_batching_buffer_guaranteed_fraction,
@@ -101,8 +101,7 @@ def get_inference_context(


def get_inference_controller(
-    model: MegatronModule,
-    context: DynamicInferenceContext,
+    model: MegatronModule, context: DynamicInferenceContext
) -> TextGenerationController:
"""Buid text generation controller, which manages the model inference context.
108
107
@@ -122,9 +121,9 @@ def get_inference_controller(

    # Note: the following is taken from AbstractModelInferenceWrapper.prep_model_for_inference().
    from megatron.core import parallel_state
+
    model.model_is_pipeline_parallel = not (
-        parallel_state.is_pipeline_first_stage() and
-        parallel_state.is_pipeline_last_stage()
+        parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage()
    )

    # Text generation controller.
@@ -134,9 +133,7 @@ def get_inference_controller(


def run_inference(
-    requests: List[Request],
-    sampling_params: SamplingParams,
-    engine: DynamicInferenceEngine,
+    requests: List[Request], sampling_params: SamplingParams, engine: DynamicInferenceEngine
) -> None:
    """Add requests to engine and generate tokens.
@@ -204,24 +201,22 @@ def run_inference(
            request.output_text = finished_request.generated_text
            request.state = "finished"
            num_requests_finished += 1
-
+
        output_times.append(get_curr_time() - output_start)

        # Check if all requests are finished.
-        if not (engine.has_unfinished_requests() or
-                num_requests_added < num_requests_total):
+        if not (engine.has_unfinished_requests() or num_requests_added < num_requests_total):
            break

    return step_times, add_times, output_times


-if __name__ == "__main__":
-
+@torch.inference_mode()
+def main():
    # Initialize Megatron.
    initialize_megatron(
        extra_args_provider=add_dynamic_inference_args,
-        args_defaults={'no_load_rng': True,
-                       'no_load_optim': True},
+        args_defaults={'no_load_rng': True, 'no_load_optim': True},
    )

    args = get_args()
@@ -243,32 +238,38 @@ def run_inference(
    controller = get_inference_controller(model, context)

    # Inference engine.
-    engine = DynamicInferenceEngine(controller,
-                                    context,
-                                    termination_id=tokenizer.eod,
-                                    enable_cuda_graph=args.enable_cuda_graph,
-                                    random_seed=args.seed)
+    engine = DynamicInferenceEngine(
+        controller,
+        context,
+        termination_id=tokenizer.eod,
+        enable_cuda_graph=args.enable_cuda_graph,
+        random_seed=args.seed,
+    )

    # Print setup.
-    setup_prefix = "dynamic | cg %d | %s | bf %.0f, flw %.1f [r %d, t %d], gtd %.2f [r %d] ... reqs %d" % (
-        args.enable_cuda_graph,
-        (
-            f"<user prompts, n {len(args.prompts)}>"
-            if args.prompts else
-            "<auto prompts> %s, %d, %.1e, %.1e" % (
-                "(%s)" % " ".join(map(str, args.num_tokens_to_prompt)),
-                args.num_tokens_to_generate,
-                args.incoming_requests_duration,
-                args.incoming_requests_per_sec,
-            )
-        ),
-        args.inference_dynamic_batching_buffer_size_gb,
-        args.inference_dynamic_batching_buffer_overflow_factor,
-        context.max_requests,
-        context.max_tokens,
-        args.inference_dynamic_batching_buffer_guaranteed_fraction,
-        context.gtd_request_count,
-        len(requests),
+    setup_prefix = (
+        "dynamic | cg %d | %s | bf %.0f, flw %.1f [r %d, t %d], gtd %.2f [r %d] ... reqs %d"
+        % (
+            args.enable_cuda_graph,
+            (
+                f"<user prompts, n {len(args.prompts)}>"
+                if args.prompts
+                else "<auto prompts> %s, %d, %.1e, %.1e"
+                % (
+                    "(%s)" % " ".join(map(str, args.num_tokens_to_prompt)),
+                    args.num_tokens_to_generate,
+                    args.incoming_requests_duration,
+                    args.incoming_requests_per_sec,
+                )
+            ),
+            args.inference_dynamic_batching_buffer_size_gb,
+            args.inference_dynamic_batching_buffer_overflow_factor,
+            context.max_requests,
+            context.max_tokens,
+            args.inference_dynamic_batching_buffer_guaranteed_fraction,
+            context.gtd_request_count,
+            len(requests),
+        )
    )
    print("~~~")
    print(setup_prefix)
@@ -297,24 +298,34 @@ def run_inference(
    for unique_idx, (prompt_text, request_idxs) in enumerate(unique_prompt_map.items()):
        request_idx = request_idxs[0]
        request = requests[request_idx]
-        print(f"{unique_idx}/{len(unique_prompt_map)} [{len(request_idxs)}]. {prompt_text} ... %s" % request.output_text.replace("\n", "\\n"))
+        print(
+            f"{unique_idx}/{len(unique_prompt_map)} [{len(request_idxs)}]. {prompt_text} ... %s"
+            % request.output_text.replace("\n", "\\n")
+        )

    # Timing results.
    stats = torch.cuda.memory_stats()
    print("~~~")
-    print("%s ... mem %.1f/%.1f ... total time: %.3f ... step time: total %.3f [ p %.3f, d %.3f ], mean [ p %.3f, d %.3f ], count [ p %d, d %d ] ... add time: %.3f, output time: %.3f." % (
-        setup_prefix,
-        stats["allocated_bytes.all.peak"] / (1024**3),
-        stats["reserved_bytes.all.peak"] / (1024**3),
-        sum(step_times["prefill"]) + sum(step_times["decode"]) + sum(add_times),
-        sum(step_times["prefill"]) + sum(step_times["decode"]),
-        sum(step_times["prefill"]),
-        sum(step_times["decode"]),
-        sum(step_times["prefill"]) / len(step_times["prefill"]),
-        sum(step_times["decode"]) / len(step_times["decode"]),
-        len(step_times["prefill"]),
-        len(step_times["decode"]),
-        sum(add_times),
-        sum(output_times),
-    ))
+    print(
+        "%s ... mem %.1f/%.1f ... total time: %.3f ... step time: total %.3f [ p %.3f, d %.3f ], mean [ p %.3f, d %.3f ], count [ p %d, d %d ] ... add time: %.3f, output time: %.3f."
+        % (
+            setup_prefix,
+            stats["allocated_bytes.all.peak"] / (1024**3),
+            stats["reserved_bytes.all.peak"] / (1024**3),
+            sum(step_times["prefill"]) + sum(step_times["decode"]) + sum(add_times),
+            sum(step_times["prefill"]) + sum(step_times["decode"]),
+            sum(step_times["prefill"]),
+            sum(step_times["decode"]),
+            sum(step_times["prefill"]) / len(step_times["prefill"]),
+            sum(step_times["decode"]) / len(step_times["decode"]),
+            len(step_times["prefill"]),
+            len(step_times["decode"]),
+            sum(add_times),
+            sum(output_times),
+        )
+    )
    print("~~~")
+
+
+if __name__ == "__main__":
+    main()
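
Side note on the timing summary printed at the end of main(): it aggregates per-step latencies by phase (prefill vs. decode) plus the request-add and output-collection overhead. Below is a minimal, self-contained sketch of that same aggregation with made-up numbers; it does not depend on Megatron and only mirrors the arithmetic in the print call above.

# Standalone illustration of the timing aggregation done at the end of main().
# The latency values here are made up for demonstration; only the arithmetic
# mirrors the print statement in the diff above.
step_times = {"prefill": [0.120, 0.115], "decode": [0.011, 0.012, 0.010]}
add_times = [0.004, 0.003]
output_times = [0.002, 0.001]

total_step = sum(step_times["prefill"]) + sum(step_times["decode"])
total_time = total_step + sum(add_times)
mean_prefill = sum(step_times["prefill"]) / len(step_times["prefill"])
mean_decode = sum(step_times["decode"]) / len(step_times["decode"])

print(
    "total time: %.3f ... step time: total %.3f [ p %.3f, d %.3f ], "
    "mean [ p %.3f, d %.3f ], count [ p %d, d %d ] ... add time: %.3f, output time: %.3f."
    % (
        total_time,
        total_step,
        sum(step_times["prefill"]),
        sum(step_times["decode"]),
        mean_prefill,
        mean_decode,
        len(step_times["prefill"]),
        len(step_times["decode"]),
        sum(add_times),
        sum(output_times),
    )
)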