@@ -2686,10 +2686,10 @@ def _train_loop(self) -> None:
2686
2686
def _eval_train_metrics (self , device_batch ):
2687
2687
assert self ._train_data_spec is not None , 'The train data spec should be set on __init__ or fit()'
2688
2688
assert self .state .train_metrics is not None , 'The train metrics should be set on __init__ or fit()'
2689
-
2689
+ precision = self . state . precision if self . state . precision is Precision . AMP_FP8 else Precision . AMP_BF16
2690
2690
with torch .no_grad (),\
2691
2691
model_eval_mode (self .state .model ),\
2692
- _get_precision_context (self . state . precision , self .state .precision_config , self .state .deepspeed_enabled ):
2692
+ _get_precision_context (precision , self .state .precision_config , self .state .deepspeed_enabled ):
2693
2693
eval_outputs = self ._original_model .eval_forward (device_batch , self .state .outputs )
2694
2694
for metric in self .state .train_metrics .values ():
2695
2695
self ._original_model .update_metric (
@@ -3484,9 +3484,9 @@ def _eval_loop(
3484
3484
)[0 ]
3485
3485
3486
3486
self .engine .run_event (Event .EVAL_BEFORE_FORWARD )
3487
-
3487
+ precision = self . state . precision if self . state . precision is Precision . AMP_FP8 else Precision . AMP_BF16
3488
3488
with _get_precision_context (
3489
- self . state . precision ,
3489
+ precision ,
3490
3490
self .state .precision_config ,
3491
3491
self .state .deepspeed_enabled ,
3492
3492
):
@@ -3501,7 +3501,7 @@ def _eval_loop(
3501
3501
3502
3502
# Run in same precision context to avoid NaNs
3503
3503
with _get_precision_context (
3504
- self . state . precision ,
3504
+ precision ,
3505
3505
self .state .precision_config ,
3506
3506
self .state .deepspeed_enabled ,
3507
3507
):
0 commit comments