Commit b3e7035

Use internal quantizer for input to the modules (#1551)
Internal quantizer for input to the modules

Signed-off-by: Przemek Tredak <[email protected]>

1 parent 5bb771e

3 files changed: +3 -3 lines


transformer_engine/pytorch/module/layernorm_linear.py

Lines changed: 1 addition & 1 deletion
@@ -1358,7 +1358,7 @@ def _get_quantizers(self, fp8_output):
         grad_output_quantizer = None
         output_quantizer = None
         input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
-        input_quantizer.internal = False
+        input_quantizer.internal = True
         weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
         weight_quantizer.internal = True
         if fp8_output:

transformer_engine/pytorch/module/layernorm_mlp.py

Lines changed: 1 addition & 1 deletion
@@ -1528,7 +1528,7 @@ def _get_quantizers(self):
         ) = [None] * 8
         if self.fp8:
             fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
-            fc1_input_quantizer.internal = False  # temporary
+            fc1_input_quantizer.internal = True
             fc1_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
             fc1_weight_quantizer.internal = True
             fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]

transformer_engine/pytorch/module/linear.py

Lines changed: 1 addition & 1 deletion
@@ -1136,7 +1136,7 @@ def _get_quantizers(self, fp8_output, fp8_grad):
         grad_output_quantizer = None
         output_quantizer = None
         input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
-        input_quantizer.internal = False
+        input_quantizer.internal = True
         weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
         weight_quantizer.internal = True
         if fp8_output:
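
All three diffs make the same one-line change inside each module's _get_quantizers helper: the forward-GEMM input quantizer taken from self.quantizers["scaling_fwd"] is now flagged internal = True, matching the treatment the weight quantizer already receives. The sketch below is a minimal illustration of that pattern only; the Quantizer class, the dictionary keys, and the stated meaning of the internal flag are assumptions for illustration, not the Transformer Engine API.

    # Minimal, self-contained sketch of the pattern this commit applies in all
    # three modules. The Quantizer class below is hypothetical; the real objects
    # come from the module's quantizer state (self.quantizers["scaling_fwd"]).
    # The comment on `internal` reflects an assumption about what the flag means,
    # not the Transformer Engine implementation.
    from dataclasses import dataclass


    @dataclass
    class Quantizer:
        name: str
        # Assumed semantics: when True, the quantizer produces a lightweight
        # internal representation of the quantized tensor instead of a
        # user-facing one.
        internal: bool = False


    def get_fwd_quantizers(quantizers: dict) -> tuple:
        """Mirror the changed _get_quantizers logic: the forward-GEMM input
        quantizer is now marked internal, matching the weight quantizer."""
        input_quantizer = quantizers["scaling_fwd"]["gemm1_input"]
        input_quantizer.internal = True  # was False before this commit
        weight_quantizer = quantizers["scaling_fwd"]["gemm1_weight"]
        weight_quantizer.internal = True  # unchanged: already internal
        return input_quantizer, weight_quantizer


    # Usage with hypothetical dictionary keys standing in for tex.FP8FwdTensors.
    quantizers = {
        "scaling_fwd": {
            "gemm1_input": Quantizer("input"),
            "gemm1_weight": Quantizer("weight"),
        }
    }
    inp_q, w_q = get_fwd_quantizers(quantizers)
    assert inp_q.internal and w_q.internal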
