Skip to content

Commit 8e62837

Browse files
author
Chuck Tang
committed
commit change
1 parent 76dacf5 commit 8e62837

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

composer/trainer/trainer.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -2686,10 +2686,10 @@ def _train_loop(self) -> None:
26862686
def _eval_train_metrics(self, device_batch):
26872687
assert self._train_data_spec is not None, 'The train data spec should be set on __init__ or fit()'
26882688
assert self.state.train_metrics is not None, 'The train metrics should be set on __init__ or fit()'
2689-
2689+
precision = self.state.precision if self.state.precision is Precision.AMP_FP8 else Precision.AMP_BF16
26902690
with torch.no_grad(),\
26912691
model_eval_mode(self.state.model),\
2692-
_get_precision_context(self.state.precision, self.state.precision_config, self.state.deepspeed_enabled):
2692+
_get_precision_context(precision, self.state.precision_config, self.state.deepspeed_enabled):
26932693
eval_outputs = self._original_model.eval_forward(device_batch, self.state.outputs)
26942694
for metric in self.state.train_metrics.values():
26952695
self._original_model.update_metric(
@@ -3484,9 +3484,9 @@ def _eval_loop(
34843484
)[0]
34853485

34863486
self.engine.run_event(Event.EVAL_BEFORE_FORWARD)
3487-
3487+
precision = self.state.precision if self.state.precision is Precision.AMP_FP8 else Precision.AMP_BF16
34883488
with _get_precision_context(
3489-
self.state.precision,
3489+
precision,
34903490
self.state.precision_config,
34913491
self.state.deepspeed_enabled,
34923492
):
@@ -3501,7 +3501,7 @@ def _eval_loop(
35013501

35023502
# Run in same precision context to avoid NaNs
35033503
with _get_precision_context(
3504-
self.state.precision,
3504+
precision,
35053505
self.state.precision_config,
35063506
self.state.deepspeed_enabled,
35073507
):

0 commit comments

Comments
 (0)