|
4 | 4 | import triton.language as tl
|
5 | 5 | import triton
|
6 | 6 |
|
7 |
| -@pytest.mark.parametrize('cond, opt_flag, env_var', [ |
8 |
| - (cond, opt_flag, env_var) for cond in [True, False] \ |
9 |
| - for opt_flag in [True, False] \ |
10 |
| - for env_var in [True, False]\ |
11 |
| -]) |
| 7 | + |
| 8 | +@pytest.mark.parametrize('cond', [True, False]) |
| 9 | +@pytest.mark.parametrize('opt_flag', [True, False, None]) |
| 10 | +@pytest.mark.parametrize('env_var', [True, False]) |
| 11 | +@pytest.mark.parametrize('jit_flag', [True, False]) |
12 | 12 | @pytest.mark.forked
|
13 |
| -def test_device_assert(cond, opt_flag, env_var, device): |
| 13 | +def test_device_assert(cond, opt_flag, env_var, jit_flag, device): |
14 | 14 | os.environ['TRITON_DEBUG'] = str(int(env_var))
|
15 | 15 | torch.zeros([1], dtype=torch.int32, device=device)
|
16 | 16 |
|
17 |
| - @triton.jit |
| 17 | + @triton.jit(debug=jit_flag) |
18 | 18 | def _kernel(COND: tl.constexpr):
|
19 | 19 | tl.device_assert(COND, 'test')
|
20 | 20 |
|
21 |
| - if not cond and (opt_flag or env_var): |
| 21 | + is_debug = env_var or (opt_flag if opt_flag is not None else jit_flag) |
| 22 | + |
| 23 | + kwargs = {} |
| 24 | + if opt_flag is not None: |
| 25 | + kwargs["debug"] = opt_flag |
| 26 | + |
| 27 | + if not cond and is_debug: |
22 | 28 | with pytest.raises(RuntimeError):
|
23 |
| - _kernel[(1, )](cond, debug=opt_flag) |
| 29 | + _kernel[(1, )](cond, **kwargs) |
24 | 30 | getattr(torch, device).synchronize()
|
25 | 31 | return
|
26 | 32 |
|
27 |
| - _kernel[(1, )](cond, debug=opt_flag) |
| 33 | + _kernel[(1, )](cond, **kwargs) |
28 | 34 | getattr(torch, device).synchronize()
|
29 | 35 |
|
30 | 36 |
|
|
0 commit comments