File tree Expand file tree Collapse file tree 1 file changed +7
-5
lines changed
triton_patch/python/triton_patch/language Expand file tree Collapse file tree 1 file changed +7
-5
lines changed Original file line number Diff line number Diff line change @@ -20,15 +20,17 @@ def flip(x, dim=None):
20
20
@jit
21
21
@math ._add_math_1arg_docstr ("sigmoid" )
22
22
def sigmoid (x ):
23
- assert core .constexpr (x .dtype .is_floating ()), "Unexpected dtype"
24
- return 1 / (1 + math .exp (- x ))
23
+ _is_floating_type : core .constexpr = x .dtype .is_floating ()
24
+ core .static_assert (_is_floating_type == True , f"Expected dtype fp16/fp32/bf16, but got { core .constexpr (x .dtype )} " )
25
+ return (1 / (1 + math .exp (- x .to (core .float32 )))).to (x .dtype )
25
26
26
27
@core ._tensor_member_fn
27
28
@jit
28
29
@math ._add_math_1arg_docstr ("softmax" )
29
30
def softmax (x , ieee_rounding = False ):
30
- assert core .constexpr (x .dtype .is_floating ()), "Unexpected dtype"
31
- z = x - max (x , 0 )
31
+ _is_floating_type : core .constexpr = x .dtype .is_floating ()
32
+ core .static_assert (_is_floating_type == True , f"Expected dtype fp16/fp32/bf16, but got { core .constexpr (x .dtype )} " )
33
+ z = x .to (core .float32 ) - max (x , 0 )
32
34
num = math .exp (z )
33
35
den = sum (num , 0 )
34
- return math .fdiv (num , den , ieee_rounding )
36
+ return math .fdiv (num , den , ieee_rounding ). to ( x . dtype )
You can’t perform that action at this time.
0 commit comments