Skip to content

Commit a43e593

Browse files
wangzhanpeng5 authored and it-is-a-robot committed
test(op): The zeros operator adds an ub overflow constraint and fix matmul of Floor Division
1 parent 73fad58 commit a43e593

File tree

3 files changed

+5
-2
lines changed

3 files changed

+5
-2
lines changed

ascend/examples/generalization_cases/test_matmul.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def test_matmul(shape, dtype):
6161
# bisheng not support yet
6262
if M % 16 != 0 or N % 16 != 0 or get_dtype_size(dtype) * K % 32 != 0:
6363
return
64-
kalign = 32 / get_dtype_size(dtype) # 32byte/Dtype_bytes
64+
kalign = 32 // get_dtype_size(dtype) # 32byte/Dtype_bytes
6565
BLOCK_M, BLOCK_N, BLOCK_K = min(max(M, 16), 32), min(max(N, 16), 32), min(max(K, kalign), 32)
6666
a = test_common.generate_tensor((M, K), dtype)
6767
b = test_common.generate_tensor((K, N), dtype)

ascend/examples/generalization_cases/test_zeros_op.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,8 @@ def fn_npu_multi_d(output_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.conste
519519
)
520520
def test_case_4d_5d(param_list):
521521
dtype, shape = param_list
522-
522+
if check_ub_mem_overflow(sigtype, shape):
523+
pytest.skip(f"dtype:{sigtype} shape:{shape} mem overflow")
523524
y_ref = torch.full(shape, 0, dtype=eval('torch.' + dtype)).npu()
524525
print(f"y_ref = {torch.flatten(y_ref)[0:4]}")
525526

ascend/examples/generalization_cases/test_zeroslike.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,8 @@ def fn_npu_multi_d(output_ptr, x_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl
136136
)
137137
def test_case_4d_5d(param_list):
138138
dtype, shape = param_list
139+
if check_ub_mem_overflow(dtype, shape):
140+
return
139141
x0 = test_common.generate_tensor(shape, dtype)
140142
y_ref = torch.zeros_like(x0, dtype=eval('torch.' + dtype)).npu()
141143
print(f"y_ref = {torch.flatten(y_ref)[0:4]}")

0 commit comments

Comments
 (0)