Commit 6a6cd47

Merge branch 'fp8_inference' into 'main'
Pad input tensors and enable fp8 weights for fp8 inference.

See merge request ADLR/megatron-lm!3341

2 parents: 0600a3c + a002d50
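
Context for the "pad input tensors" part of this change: FP8 GEMM kernels (e.g. in Transformer Engine) generally require the token dimension to be aligned to a fixed multiple, so a dynamically batched input is padded before the forward pass and the padding rows are dropped afterwards. The snippet below is only a hypothetical sketch of that idea, not code from this merge request; the helper name pad_tokens_for_fp8 and the FP8_ALIGNMENT value are assumptions.

import torch

# Hypothetical illustration (not from this MR): FP8 GEMMs typically need the token
# dimension padded to a fixed alignment; pad before the forward pass, slice after.
FP8_ALIGNMENT = 16  # assumed value; the real requirement depends on the kernel / TE version


def pad_tokens_for_fp8(hidden: torch.Tensor) -> tuple[torch.Tensor, int]:
    """Pad the leading (token) dimension of `hidden` up to a multiple of FP8_ALIGNMENT."""
    num_tokens = hidden.size(0)
    padded_len = -(-num_tokens // FP8_ALIGNMENT) * FP8_ALIGNMENT  # ceil to a multiple
    if padded_len == num_tokens:
        return hidden, num_tokens
    pad = hidden.new_zeros((padded_len - num_tokens,) + tuple(hidden.shape[1:]))
    return torch.cat([hidden, pad], dim=0), num_tokens


# Usage sketch: padded, n = pad_tokens_for_fp8(hidden_states); out = model(padded)[:n]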


9 files changed (+216 additions, -103 deletions)


examples/inference/gpt/gpt_dynamic_inference.py

Lines changed: 81 additions & 70 deletions
@@ -11,16 +11,15 @@
     DynamicInferenceContext,
 )
 from megatron.core.inference.engines import DynamicInferenceEngine
-from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import GPTInferenceWrapper
+from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
+    GPTInferenceWrapper,
+)
 from megatron.core.inference.sampling_params import SamplingParams
-from megatron.core.inference.text_generation_controllers.text_generation_controller import TextGenerationController
-from megatron.core.transformer.module import MegatronModule
-from megatron.training import (
-    get_args,
-    get_model as _get_model,
-    get_tokenizer,
-    initialize_megatron,
+from megatron.core.inference.text_generation_controllers.text_generation_controller import (
+    TextGenerationController,
 )
+from megatron.core.transformer.module import MegatronModule
+from megatron.training import get_args, get_model as _get_model, get_tokenizer, initialize_megatron
 from megatron.training.checkpointing import load_checkpoint
 from pretrain_gpt import model_provider

@@ -33,9 +32,11 @@ def add_dynamic_inference_args(parser: ArgumentParser) -> ArgumentParser:
     add_common_inference_args(parser)
 
     group = parser.add_argument_group(title='Dynamic inference')
-    group.add_argument("--inference-ckpt-non-strict", action="store_true",
-                       help="Load checkpoint with `strict=False`.")
-
+    group.add_argument(
+        "--inference-ckpt-non-strict",
+        action="store_true",
+        help="Load checkpoint with `strict=False`.",
+    )
 
     return parser

@@ -68,10 +69,7 @@ def get_model() -> MegatronModule:
     return model
 
 
-def get_inference_context(
-    requests: List[Request],
-    sampling_params: SamplingParams,
-):
+def get_inference_context(requests: List[Request], sampling_params: SamplingParams):
     """The inference context manages the KV cache and other inference state."""
 
     args = get_args()
@@ -86,7 +84,9 @@ def get_inference_context(
         params_dtype=args.params_dtype,
         num_layers=args.num_layers,
         kv_channels=args.kv_channels,
-        num_attention_heads=args.num_query_groups if args.group_query_attention else args.num_attention_heads,
+        num_attention_heads=(
+            args.num_query_groups if args.group_query_attention else args.num_attention_heads
+        ),
         max_sequence_length=max_sequence_length,
         buffer_size_gb=args.inference_dynamic_batching_buffer_size_gb,
         buffer_guaranteed_fraction=args.inference_dynamic_batching_buffer_guaranteed_fraction,
@@ -101,8 +101,7 @@ def get_inference_context(
 
 
 def get_inference_controller(
-    model: MegatronModule,
-    context: DynamicInferenceContext,
+    model: MegatronModule, context: DynamicInferenceContext
 ) -> TextGenerationController:
     """Buid text generation controller, which manages the model inference context.
 
122121

123122
# Note: the following is taken from AbstractModelInferenceWrapper.prep_model_for_inference().
124123
from megatron.core import parallel_state
124+
125125
model.model_is_pipeline_parallel = not (
126-
parallel_state.is_pipeline_first_stage() and
127-
parallel_state.is_pipeline_last_stage()
126+
parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage()
128127
)
129128

130129
# Text generation controller.
@@ -134,9 +133,7 @@ def get_inference_controller(
 
 
 def run_inference(
-    requests: List[Request],
-    sampling_params: SamplingParams,
-    engine: DynamicInferenceEngine,
+    requests: List[Request], sampling_params: SamplingParams, engine: DynamicInferenceEngine
 ) -> None:
     """Add requests to engine and generate tokens.
 
@@ -204,24 +201,22 @@ def run_inference(
             request.output_text = finished_request.generated_text
             request.state = "finished"
             num_requests_finished += 1
-
+
         output_times.append(get_curr_time() - output_start)
 
         # Check if all requests are finished.
-        if not (engine.has_unfinished_requests() or
-                num_requests_added < num_requests_total):
+        if not (engine.has_unfinished_requests() or num_requests_added < num_requests_total):
             break
 
     return step_times, add_times, output_times
 
 
-if __name__ == "__main__":
-
+@torch.inference_mode()
+def main():
     # Initialize Megatron.
     initialize_megatron(
         extra_args_provider=add_dynamic_inference_args,
-        args_defaults={'no_load_rng': True,
-                       'no_load_optim': True},
+        args_defaults={'no_load_rng': True, 'no_load_optim': True},
     )
 
     args = get_args()
@@ -243,32 +238,38 @@ def run_inference(
     controller = get_inference_controller(model, context)
 
     # Inference engine.
-    engine = DynamicInferenceEngine(controller,
-                                    context,
-                                    termination_id=tokenizer.eod,
-                                    enable_cuda_graph=args.enable_cuda_graph,
-                                    random_seed=args.seed)
+    engine = DynamicInferenceEngine(
+        controller,
+        context,
+        termination_id=tokenizer.eod,
+        enable_cuda_graph=args.enable_cuda_graph,
+        random_seed=args.seed,
+    )
 
     # Print setup.
-    setup_prefix = "dynamic | cg %d | %s | bf %.0f, flw %.1f [r %d, t %d], gtd %.2f [r %d] ... reqs %d" % (
-        args.enable_cuda_graph,
-        (
-            f"<user prompts, n {len(args.prompts)}>"
-            if args.prompts else
-            "<auto prompts> %s, %d, %.1e, %.1e" % (
-                "(%s)" % " ".join(map(str, args.num_tokens_to_prompt)),
-                args.num_tokens_to_generate,
-                args.incoming_requests_duration,
-                args.incoming_requests_per_sec,
-            )
-        ),
-        args.inference_dynamic_batching_buffer_size_gb,
-        args.inference_dynamic_batching_buffer_overflow_factor,
-        context.max_requests,
-        context.max_tokens,
-        args.inference_dynamic_batching_buffer_guaranteed_fraction,
-        context.gtd_request_count,
-        len(requests),
+    setup_prefix = (
+        "dynamic | cg %d | %s | bf %.0f, flw %.1f [r %d, t %d], gtd %.2f [r %d] ... reqs %d"
+        % (
+            args.enable_cuda_graph,
+            (
+                f"<user prompts, n {len(args.prompts)}>"
+                if args.prompts
+                else "<auto prompts> %s, %d, %.1e, %.1e"
+                % (
+                    "(%s)" % " ".join(map(str, args.num_tokens_to_prompt)),
+                    args.num_tokens_to_generate,
+                    args.incoming_requests_duration,
+                    args.incoming_requests_per_sec,
+                )
+            ),
+            args.inference_dynamic_batching_buffer_size_gb,
+            args.inference_dynamic_batching_buffer_overflow_factor,
+            context.max_requests,
+            context.max_tokens,
+            args.inference_dynamic_batching_buffer_guaranteed_fraction,
+            context.gtd_request_count,
+            len(requests),
+        )
     )
     print("~~~")
     print(setup_prefix)
@@ -297,24 +298,34 @@ def run_inference(
     for unique_idx, (prompt_text, request_idxs) in enumerate(unique_prompt_map.items()):
         request_idx = request_idxs[0]
         request = requests[request_idx]
-        print(f"{unique_idx}/{len(unique_prompt_map)} [{len(request_idxs)}]. {prompt_text} ... %s" % request.output_text.replace("\n", "\\n"))
+        print(
+            f"{unique_idx}/{len(unique_prompt_map)} [{len(request_idxs)}]. {prompt_text} ... %s"
+            % request.output_text.replace("\n", "\\n")
+        )
 
     # Timing results.
     stats = torch.cuda.memory_stats()
     print("~~~")
-    print("%s ... mem %.1f/%.1f ... total time: %.3f ... step time: total %.3f [ p %.3f, d %.3f ], mean [ p %.3f, d %.3f ], count [ p %d, d %d ] ... add time: %.3f, output time: %.3f." % (
-        setup_prefix,
-        stats["allocated_bytes.all.peak"] / (1024**3),
-        stats["reserved_bytes.all.peak"] / (1024**3),
-        sum(step_times["prefill"]) + sum(step_times["decode"]) + sum(add_times),
-        sum(step_times["prefill"]) + sum(step_times["decode"]),
-        sum(step_times["prefill"]),
-        sum(step_times["decode"]),
-        sum(step_times["prefill"]) / len(step_times["prefill"]),
-        sum(step_times["decode"]) / len(step_times["decode"]),
-        len(step_times["prefill"]),
-        len(step_times["decode"]),
-        sum(add_times),
-        sum(output_times),
-    ))
+    print(
+        "%s ... mem %.1f/%.1f ... total time: %.3f ... step time: total %.3f [ p %.3f, d %.3f ], mean [ p %.3f, d %.3f ], count [ p %d, d %d ] ... add time: %.3f, output time: %.3f."
+        % (
+            setup_prefix,
+            stats["allocated_bytes.all.peak"] / (1024**3),
+            stats["reserved_bytes.all.peak"] / (1024**3),
+            sum(step_times["prefill"]) + sum(step_times["decode"]) + sum(add_times),
+            sum(step_times["prefill"]) + sum(step_times["decode"]),
+            sum(step_times["prefill"]),
+            sum(step_times["decode"]),
+            sum(step_times["prefill"]) / len(step_times["prefill"]),
+            sum(step_times["decode"]) / len(step_times["decode"]),
+            len(step_times["prefill"]),
+            len(step_times["decode"]),
+            sum(add_times),
+            sum(output_times),
+        )
+    )
     print("~~~")
+
+
+if __name__ == "__main__":
+    main()
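
A note on the --inference-ckpt-non-strict flag added above: per its help text it loads the checkpoint with strict=False, which in PyTorch means missing or unexpected state-dict keys are reported rather than raising (useful when an fp8 inference checkpoint does not match the module's expected keys exactly). A minimal, self-contained illustration with plain torch.nn rather than Megatron's load_checkpoint:

import torch
import torch.nn as nn

# strict=False tolerates mismatched keys and reports them instead of raising.
model = nn.Linear(4, 4)
state = {"weight": torch.zeros(4, 4)}                # 'bias' intentionally omitted
result = model.load_state_dict(state, strict=False)
print(result.missing_keys)                           # ['bias']
print(result.unexpected_keys)                        # []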

examples/inference/gpt/gpt_static_inference.py

Lines changed: 1 addition & 0 deletions
@@ -131,6 +131,7 @@ async def collect_stream(prompt, request_id, stream_generator):
     return results
 
 
+@torch.inference_mode()
 def main():
     """Main program."""
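
Both inference entry points now run under @torch.inference_mode(), which disables autograd tracking (and the associated view/version-counter bookkeeping) for everything executed inside main(). A minimal sketch of the pattern:

import torch


@torch.inference_mode()
def main():
    x = torch.randn(4, 4)
    y = x @ x
    # No autograd graph is recorded for tensors created here.
    print(y.requires_grad, torch.is_grad_enabled())  # False False


if __name__ == "__main__":
    main()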

megatron/core/fp8_utils.py

Lines changed: 2 additions & 1 deletion
@@ -385,6 +385,7 @@ def get_fp8_context(config: TransformerConfig, layer_no: int = -1, is_init: bool
     We return nullcontext() when: a) not using fp8 to train, b) layer_no is a layer
     that needs to be trained in bf16.
     """
+
     num_bf16_layers_at_start = (
         config.num_layers_at_start_in_bf16 if config.first_last_layers_bf16 else 0
     )
@@ -478,7 +479,7 @@ def get_fp8_context(config: TransformerConfig, layer_no: int = -1, is_init: bool
         if "preserve_high_precision_init_val" in (
             inspect.signature(transformer_engine.pytorch.fp8_model_init).parameters
         ):
-            context_args["preserve_high_precision_init_val"] = True
+            context_args["preserve_high_precision_init_val"] = torch.is_grad_enabled()
         fp8_context = transformer_engine.pytorch.fp8_model_init(**context_args)
 
         # First / last layer in bf16 isn't supported with delayed scaling since it
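
The second hunk ties into the "enable fp8 weights" part of the title: preserve_high_precision_init_val now follows the grad mode, so when get_fp8_context() runs under the new @torch.inference_mode() entry points, transformer_engine's fp8_model_init() is not asked to keep high-precision copies of the initial weight values, while grad-enabled (training-style) initialization keeps the flag True as before. A minimal sketch of the grad-mode behavior this line relies on:

import torch

# torch.is_grad_enabled() is False inside inference_mode and True otherwise, so the
# flag above is only set during grad-enabled initialization.
print(torch.is_grad_enabled())        # True in a normal training context

with torch.inference_mode():
    print(torch.is_grad_enabled())    # False during inference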
