Skip to content

Commit e453e36

Browse files
张春立 and it-is-a-robot
authored and committed
fix(accuracy): cast input of exp(x) to fp32 in order to guarantee accuracy of sigmoid and softmax
1 parent 74728ae commit e453e36

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

triton_patch/python/triton_patch/language/standard.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,17 @@ def flip(x, dim=None):
2020
@jit
@math._add_math_1arg_docstr("sigmoid")
def sigmoid(x):
    # Compile-time guard: sigmoid is only defined for floating-point tensors.
    # NOTE: bound to a constexpr local so static_assert can evaluate it at trace time.
    _input_is_float: core.constexpr = x.dtype.is_floating()
    core.static_assert(_input_is_float == True, f"Expected dtype fp16/fp32/bf16, but got {core.constexpr(x.dtype)}")
    # Evaluate exp in fp32 — exp(-x) in fp16/bf16 loses accuracy — then
    # cast the result back to the caller's dtype so the interface is unchanged.
    x_f32 = x.to(core.float32)
    result_f32 = 1 / (1 + math.exp(-x_f32))
    return result_f32.to(x.dtype)
2526

2627
@core._tensor_member_fn
@jit
@math._add_math_1arg_docstr("softmax")
def softmax(x, ieee_rounding=False):
    # Compile-time guard: softmax is only defined for floating-point tensors.
    _input_is_float: core.constexpr = x.dtype.is_floating()
    core.static_assert(_input_is_float == True, f"Expected dtype fp16/fp32/bf16, but got {core.constexpr(x.dtype)}")
    # Numerically stable softmax: subtract the row max before exponentiating.
    # The subtraction and exp run in fp32 for accuracy; the final quotient is
    # cast back to the caller's dtype so the interface is unchanged.
    x_f32 = x.to(core.float32)
    shifted = x_f32 - max(x, 0)
    numerator = math.exp(shifted)
    denominator = sum(numerator, 0)
    return math.fdiv(numerator, denominator, ieee_rounding).to(x.dtype)

0 commit comments

Comments
 (0)