Skip to content

Commit 3ca2f49

Browse files
authored
[FRONTEND] Fix @triton.jit(debug=True) (#5037)
Reported in https://github.com/triton-lang/triton/pull/5033/files#r1825438371 This was broken by #4589
1 parent d0db12b commit 3ca2f49

File tree

3 files changed

+18
-12
lines changed

3 files changed

+18
-12
lines changed

python/test/unit/test_debug.py

+16-10
Original file line numberDiff line numberDiff line change
@@ -4,27 +4,33 @@
44
import triton.language as tl
55
import triton
66

7-
@pytest.mark.parametrize('cond, opt_flag, env_var', [
8-
(cond, opt_flag, env_var) for cond in [True, False] \
9-
for opt_flag in [True, False] \
10-
for env_var in [True, False]\
11-
])
7+
8+
@pytest.mark.parametrize('cond', [True, False])
9+
@pytest.mark.parametrize('opt_flag', [True, False, None])
10+
@pytest.mark.parametrize('env_var', [True, False])
11+
@pytest.mark.parametrize('jit_flag', [True, False])
1212
@pytest.mark.forked
13-
def test_device_assert(cond, opt_flag, env_var, device):
13+
def test_device_assert(cond, opt_flag, env_var, jit_flag, device):
1414
os.environ['TRITON_DEBUG'] = str(int(env_var))
1515
torch.zeros([1], dtype=torch.int32, device=device)
1616

17-
@triton.jit
17+
@triton.jit(debug=jit_flag)
1818
def _kernel(COND: tl.constexpr):
1919
tl.device_assert(COND, 'test')
2020

21-
if not cond and (opt_flag or env_var):
21+
is_debug = env_var or (opt_flag if opt_flag is not None else jit_flag)
22+
23+
kwargs = {}
24+
if opt_flag is not None:
25+
kwargs["debug"] = opt_flag
26+
27+
if not cond and is_debug:
2228
with pytest.raises(RuntimeError):
23-
_kernel[(1, )](cond, debug=opt_flag)
29+
_kernel[(1, )](cond, **kwargs)
2430
getattr(torch, device).synchronize()
2531
return
2632

27-
_kernel[(1, )](cond, debug=opt_flag)
33+
_kernel[(1, )](cond, **kwargs)
2834
getattr(torch, device).synchronize()
2935

3036

python/triton/language/core.py

-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ def builtin(fn: T) -> T:
2929
@wraps(fn)
3030
def wrapper(*args, **kwargs):
3131
if "_builder" not in kwargs or kwargs["_builder"] is None:
32-
print(kwargs)
3332
raise ValueError("Did you forget to add @triton.jit ? "
3433
"(`_builder` argument must be provided outside of JIT functions.)")
3534
return fn(*args, **kwargs)

python/triton/runtime/jit.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -561,7 +561,7 @@ def create_binder(self, backend):
561561
]
562562

563563
def run(self, *args, grid, warmup, **kwargs):
564-
kwargs["debug"] = kwargs.get("debug", False) or os.environ.get("TRITON_DEBUG", "0") == "1"
564+
kwargs["debug"] = kwargs.get("debug", self.debug) or os.environ.get("TRITON_DEBUG", "0") == "1"
565565

566566
# parse options
567567
from ..compiler import make_backend
@@ -698,6 +698,7 @@ def __init__(self, fn, version=None, do_not_specialize=None, do_not_specialize_o
698698
# JITFunction can be instantiated as kernel
699699
# when called with a grid using __getitem__
700700
self.kernel = None
701+
self.debug = debug
701702
self.noinline = noinline
702703

703704
# TODO(jlebar): Remove uses of these fields outside this file, then

0 commit comments

Comments
 (0)