
Commit 9f5ad4f

melo882 authored and it-is-a-robot committed

test(op): modify bool support list for some ops

1 parent ba1e078

File tree: 10 files changed (+49, -83 lines)

ascend/examples/generalization_cases/test_common.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -145,6 +145,8 @@ def generate_tensor_int_withSigns(shape, dtype):
         return torch.randint(low=-32768, high=32767, size=shape, dtype=eval('torch.' + dtype))
     elif dtype == 'int8':
         return torch.randint(low=-128, high=127, size=shape, dtype=eval('torch.' + dtype))
+    elif dtype == 'bool':
+        return torch.randint(low=0, high=2, size=shape).bool()
     else:
         raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype))

```
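For reference, the new branch builds a boolean tensor by sampling 0/1 integers and casting; a minimal standalone sketch of the same pattern (the shape is chosen arbitrarily):

```python
import torch

# Sample 0/1 uniformly, then cast to torch.bool -- same pattern as the new branch.
t = torch.randint(low=0, high=2, size=(4, 8)).bool()
assert t.dtype == torch.bool  # values are drawn from {False, True}
```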

ascend/examples/generalization_cases/test_eq.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -63,7 +63,7 @@ def triton_eq_4d_5d(


 @pytest.mark.parametrize('shape', TestUtils.test_shape1_2_3d)
-@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32'])
+@pytest.mark.parametrize('dtype', ['bool', 'int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32'])
 def test_eq(shape, dtype):
     logging.debug(f'dtype:{dtype} shape:{shape}')
     # generate data
```
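With 'bool' in the parametrization, the reference side of the test only relies on torch.eq, which is defined for torch.bool; a short illustration (shapes arbitrary):

```python
import torch

x = torch.randint(0, 2, (8,)).bool()
y = torch.randint(0, 2, (8,)).bool()
ans = torch.eq(x, y)  # elementwise equality is well defined on bool tensors
assert ans.dtype == torch.bool
```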

ascend/examples/generalization_cases/test_ge_op.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -81,7 +81,7 @@ def triton_ge_4d_5d(
     tl.store(output_ptr + offsets, ret, mask=masks)


-typelist = ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']
+typelist = ['bool', 'int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']

 dtype_mapping = {
     'int8': (torch.int8),
```

ascend/examples/generalization_cases/test_general_floordiv.py

Lines changed: 14 additions & 31 deletions

```diff
@@ -66,7 +66,7 @@ def triton_floordiv_4d_5d(


 @pytest.mark.parametrize('shape', TestUtils.full_shape)  # some shape with int8 over ub
-@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64'])
+@pytest.mark.parametrize('dtype', ['bool', 'int8', 'int16', 'int32', 'int64'])
 def test_floordiv(shape, dtype):
     logging.log(logging.DEBUG, f"shape = {shape}")
     x = test_common.generate_tensor_int_withSigns(shape, dtype).npu()
@@ -83,38 +83,21 @@ def test_floordiv(shape, dtype):
     ans = ans + ans_mask

     if len(shape) == 1:
-        XB = 1
-        xnumel = 1
-        YB = 1
-        ynumel = 1
-        ZB = shape[0]
-        znumel = shape[0]
+        triton_floordiv[1, 1, shape[0]](output, x, y, z, 1, 1, 1, 1, 1, shape[0])
     elif len(shape) == 2:
-        XB = 1
-        xnumel = 1
-        YB = shape[0]
-        ynumel = shape[0]
-        ZB = shape[1]
-        znumel = shape[1]
-    else:
-        XB = shape[0]
-        xnumel = shape[0]
-        YB = shape[1]
-        ynumel = shape[1]
-        ZB = shape[2]
-        znumel = shape[2]
-
-    grid = (1, 1, 1)
-    if dtype == 'int8':
-        if x.numel() * x.element_size() >= 512:
-            grid = (1, 1, ZB)
-            ZB = 1
+        if shape[0] > shape[1]:
+            triton_floordiv[1, shape[0], 1](output, x, y, z, 1, 1, shape[1], 1, shape[0], shape[1])
+        else:
+            triton_floordiv[1, 1, shape[1]](output, x, y, z, 1, shape[0], 1, 1, shape[0], shape[1])
+    elif len(shape) == 3:
+        if max(shape[0], shape[1], shape[2]) == shape[0]:
+            triton_floordiv[shape[0], 1, 1](output, x, y, z, 1, shape[1], shape[2], shape[0], shape[1], shape[2])
+        elif max(shape[0], shape[1], shape[2]) == shape[1]:
+            triton_floordiv[1, shape[1], 1](output, x, y, z, shape[0], 1, shape[2], shape[0], shape[1], shape[2])
+        else:
+            triton_floordiv[1, 1, shape[2]](output, x, y, z, shape[0], shape[1], 1, shape[0], shape[1], shape[2])
     else:
-        if x.numel() * x.element_size() >= 8192:
-            grid = (1, 1, ZB)
-            ZB = 1
-
-    triton_floordiv[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel)
+        triton_floordiv[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1)

     test_common.validate_cmp(dtype, ans, output)
```
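The rewrite replaces the XB/ZB bookkeeping with direct launches: the largest dimension of the shape goes onto the 3D launch grid with a block size of 1 along that axis, while the remaining axes keep full-size blocks. A hedged sketch of that dispatch rule as a standalone helper (the helper name is hypothetical, not part of the patch):

```python
def launch_grid_and_blocks(shape):
    """Return (grid, (XB, YB, ZB)) following the dispatch above: the largest
    dimension is mapped onto the 3D launch grid, with a block size of 1 along
    the gridded axis and full coverage along the others."""
    if len(shape) == 1:
        return (1, 1, shape[0]), (1, 1, 1)
    if len(shape) == 2:
        if shape[0] > shape[1]:
            return (1, shape[0], 1), (1, 1, shape[1])
        return (1, 1, shape[1]), (1, shape[0], 1)
    if len(shape) == 3:
        if max(shape) == shape[0]:
            return (shape[0], 1, 1), (1, shape[1], shape[2])
        if max(shape) == shape[1]:
            return (1, shape[1], 1), (shape[0], 1, shape[2])
        return (1, 1, shape[2]), (shape[0], shape[1], 1)
    return (1, 1, 1), (1, 1, 1)  # fallback mirrors the final else branch
```

Spreading the largest axis across the grid keeps each program's tile small, which plausibly replaces the deleted `x.numel() * x.element_size()` thresholds that guarded against ub overflow.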

ascend/examples/generalization_cases/test_gt_op.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -81,7 +81,7 @@ def triton_gt_4d_5d(
     tl.store(output_ptr + offsets, ret, mask=masks)


-typelist = ['int8','int16','int32','int64','float16','bfloat16','float32']
+typelist = ['bool', 'int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']

 dtype_mapping = {
     'int8': (torch.int8),
```

ascend/examples/generalization_cases/test_le_op.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -81,7 +81,7 @@ def triton_le_4d_5d(
     tl.store(output_ptr + offsets, ret, mask=masks)


-typelist = ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']
+typelist = ['bool', 'int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']

 dtype_mapping = {
     'int8': (torch.int8),
```

ascend/examples/generalization_cases/test_lt_op.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -81,7 +81,7 @@ def triton_lt_4d_5d(
     tl.store(output_ptr + offsets, ret, mask=masks)


-typelist = ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']
+typelist = ['bool', 'int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']

 dtype_mapping = {
     'int8': (torch.int8),
```

ascend/examples/generalization_cases/test_mod.py

Lines changed: 14 additions & 34 deletions

```diff
@@ -79,46 +79,26 @@ def test_case2(dtype, shape):
     new_shape = shape
     z[z <= 0] = 1

-    output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu()
-    output1 = output
-    logging.debug(f"output.dtype={output.dtype}")
-
     ans = torch_pointwise(x.cpu(), y.cpu())
     ans = ans.npu()
+    output = torch.zeros_like(ans)

     if len(shape) == 1:
-        XB = 1
-        xnumel = 1
-        YB = 1
-        ynumel = 1
-        ZB = shape[0]
-        znumel = shape[0]
+        fn_npu_[1, 1, shape[0]](output, x, y, z, 1, 1, 1, 1, 1, shape[0])
     elif len(shape) == 2:
-        XB = 1
-        xnumel = 1
-        YB = shape[0]
-        ynumel = shape[0]
-        ZB = shape[1]
-        znumel = shape[1]
-    else:
-        XB = shape[0]
-        xnumel = shape[0]
-        YB = shape[1]
-        ynumel = shape[1]
-        ZB = shape[2]
-        znumel = shape[2]
-
-    grid = (1, 1, 1)
-    if dtype == 'int8':
-        if x.numel() * x.element_size() >= 512:
-            grid = (1, 1, ZB)
-            ZB = 1
+        if shape[0] > shape[1]:
+            fn_npu_[1, shape[0], 1](output, x, y, z, 1, 1, shape[1], 1, shape[0], shape[1])
+        else:
+            fn_npu_[1, 1, shape[1]](output, x, y, z, 1, shape[0], 1, 1, shape[0], shape[1])
+    elif len(shape) == 3:
+        if max(shape[0], shape[1], shape[2]) == shape[0]:
+            fn_npu_[shape[0], 1, 1](output, x, y, z, 1, shape[1], shape[2], shape[0], shape[1], shape[2])
+        elif max(shape[0], shape[1], shape[2]) == shape[1]:
+            fn_npu_[1, shape[1], 1](output, x, y, z, shape[0], 1, shape[2], shape[0], shape[1], shape[2])
+        else:
+            fn_npu_[1, 1, shape[2]](output, x, y, z, shape[0], shape[1], 1, shape[0], shape[1], shape[2])
     else:
-        if x.numel() * x.element_size() >= 8192:
-            grid = (1, 1, ZB)
-            ZB = 1
-
-    fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel)
+        fn_npu_[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1)

     test_common.validate_cmp(dtype, ans, output)
```
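test_mod.py gets the same launch-grid rewrite as test_general_floordiv.py; the additional change is that the output buffer is now derived from the reference result, so its dtype and shape track whatever torch_pointwise returns, bool included. A minimal illustration:

```python
import torch

# zeros_like keeps the output buffer's dtype and shape in lockstep with the
# reference answer -- no dtype-specific allocation path is needed for bool.
ans = torch.tensor([True, False, True])
output = torch.zeros_like(ans)
assert output.dtype == torch.bool and output.shape == ans.shape
```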

ascend/examples/generalization_cases/test_ne.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -63,7 +63,7 @@ def triton_ne_4d_5d(


 @pytest.mark.parametrize('shape', TestUtils.test_shape1_2_3d)
-@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32'])
+@pytest.mark.parametrize('dtype', ['bool', 'int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32'])
 def test_ne(shape, dtype):
     logging.debug(f'dtype:{dtype} shape:{shape}')
     # generate data
```

docs/sources/python-api/outline.md

Lines changed: 13 additions & 12 deletions

```diff
@@ -28,12 +28,12 @@
 | | make_block_ptr |||| × ||||| × |
 | | advance |||| × ||||| × |
 | Indexing Ops | flip |||| × ||||||
-| | where |||| × | × |||| |
+| | where |||| × | × ||||* |
 | | swizzle2d |||| × || × | × | × | × |
-| Math Ops | add |||| × ||||| |
-| | sub |||| × ||||| |
-| | mul |||| × ||||| |
-| | div |||| × ||||| |
+| Math Ops | add |||| × |||||* |
+| | sub |||| × |||||* |
+| | mul |||| × |||||* |
+| | div |||| × |||||* |
 | | floordiv(//) |||| × || × | × | × | × |
 | | mod |||| × | × | × | × | × | × |
 | | neg |||| × ||||| × |
@@ -44,12 +44,12 @@
 | | not(~) |||| × || × | × | × ||
 | | lshift(<<) |||| × || × | × | × | × |
 | | rshift(>>) |||| × || × | × | × | × |
-| | gt |||| × ||||| × |
-| | ge |||| × ||||| × |
-| | lt |||| × ||||| × |
-| | le |||| × ||||| × |
-| | eq |||| × ||||| × |
-| | ne |||| × ||||| × |
+| | gt |||| × ||||| * |
+| | ge |||| × ||||| * |
+| | lt |||| × ||||| * |
+| | le |||| × ||||| * |
+| | eq |||| × ||||| * |
+| | ne |||| × ||||| * |
 | | logical and | × | × | × | × | × | × | × | × ||
 | | logical or | × | × | × | × | × | × | × | × ||
 | | abs |||| × |||||* |
@@ -138,4 +138,5 @@

 - ALL: because int8 receives special handling, it occupies more on-chip space and can easily trigger a ub overflow error at compile time; adjusting the tiling usually resolves this;
   all tensors alive at the same time in a triton kernel must not total more than 96KB, or 192KB if double buffering is disabled;
-  no tensor may have a shape dimension whose size is less than 1.
+  no tensor may have a shape dimension whose size is less than 1;
+  * marks OPs for which triton internally converts the bool type to int8 for computation and can still produce the result.
```
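To make the `*` footnote concrete: a comparison on bool inputs can be evaluated by treating the 1-byte bool operands as int8 and comparing those. A hedged Triton sketch of the pattern (the kernel name and layout are illustrative only, not the repo's kernels):

```python
import triton
import triton.language as tl

@triton.jit
def eq_bool_kernel(out_ptr, x_ptr, y_ptr, N: tl.constexpr):
    # bool tensors occupy one byte per element; casting the loaded values to
    # int8 mirrors the internal bool -> int8 conversion the footnote describes.
    offs = tl.arange(0, N)
    x = tl.load(x_ptr + offs).to(tl.int8)
    y = tl.load(y_ptr + offs).to(tl.int8)
    tl.store(out_ptr + offs, x == y)

# Hypothetical usage: eq_bool_kernel[(1,)](out, x, y, N=16) with bool tensors
# (or their .view(torch.int8) reinterpretations) of 16 elements each.
```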
