Closed as not planned
Describe the bug
Issue Description
While compiling Llama 3.1 with the advanced path, we found that it throws a PassManager error.
Reproducer
TRITON_INTEL_ADVANCED_PATH=1 python reproducer.py
import unittest

import torch
from torch.nn.attention.flex_attention import flex_attention


class TestFlexAttentionCompile(unittest.TestCase):
    def setUp(self):
        self.batch_size = 1
        self.n_heads = 32
        self.seq_length = 16
        self.head_dim = 32
        self.device = torch.device("xpu")
        self.query = torch.randn(
            self.batch_size,
            self.n_heads,
            self.seq_length,
            self.head_dim,
            device=self.device,
            dtype=torch.float32,
        )
        self.key = torch.randn(
            self.batch_size,
            self.n_heads,
            self.seq_length,
            self.head_dim,
            device=self.device,
            dtype=torch.float32,
        )
        self.value = torch.randn(
            self.batch_size,
            self.n_heads,
            self.seq_length,
            self.head_dim,
            device=self.device,
            dtype=torch.float32,
        )

    def flex_attention_run(self, query, key, value):
        def causal_mask(score, b, h, q_idx, kv_idx):
            return torch.where(q_idx >= kv_idx, score, -float("inf"))

        return flex_attention(
            query,
            key,
            value,
            score_mod=causal_mask,
            return_lse=True,
        )

    def test_compile_flex_attention(self):
        compiled_flex_attn = torch.compile(
            self.flex_attention_run,
        )
        out_uncompiled = compiled_flex_attn(
            self.query,
            self.key,
            self.value,
        )
        print(out_uncompiled)


if __name__ == "__main__":
    unittest.main()
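For comparison, a minimal eager-mode control run (a sketch, not part of the original reproducer; shapes and the causal score_mod are copied from the test above) can help confirm that the failure is specific to the torch.compile / Inductor / Triton advanced-path lowering rather than to flex_attention itself:

# Hypothetical eager-mode control run (assumption: not from the original
# report). If this succeeds while the compiled test above fails, the problem
# is isolated to the compiled path rather than flex_attention itself.
import torch
from torch.nn.attention.flex_attention import flex_attention

def causal_mask(score, b, h, q_idx, kv_idx):
    return torch.where(q_idx >= kv_idx, score, -float("inf"))

# Same shapes as the unit test: (batch, heads, seq_len, head_dim) = (1, 32, 16, 32)
q = torch.randn(1, 32, 16, 32, device="xpu", dtype=torch.float32)
k = torch.randn_like(q)
v = torch.randn_like(q)

out, lse = flex_attention(q, k, v, score_mod=causal_mask, return_lse=True)
print(out.shape, lse.shape)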
Error message
======================================================================
ERROR: test_compile_flex_attention (__main__.TestFlexAttentionCompile)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/workspace1/xingyuan/20250213-flexatt-enable/reproduce-adv-path-compile.py", line 60, in test_compile_flex_attention
out_uncompiled = compiled_flex_attn(
File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_dynamo/eval_frame.py", line 574, in _fn
raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1
File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_dynamo/output_graph.py", line 1487, in _call_user_compiler
raise BackendCompilerFailed(
File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_dynamo/output_graph.py", line 1466, in _call_user_compiler
compiled_fn = compiler_fn(gm, self.example_inputs())
File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_dynamo/repro/after_dynamo.py", line 131, in __call__
compiled_gm = compiler_fn(gm, example_inputs)
File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/__init__.py", line 2339, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_inductor/compile_fx.py", line 2163, in compile_fx
return aot_autograd(
File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_dynamo/backends/common.py", line 83, in __call__
cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_functorch/aot_autograd.py", line 1158, in aot_module_simplified
compiled_fn = AOTAutogradCache.load(
File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_functorch/_aot_autograd/autograd_cache.py", line 779, in load
compiled_fn = dispatch_and_compile()
File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_functorch/aot_autograd.py", line 1143, in dispatch_and_compile
compiled_fn, _ = create_aot_dispatcher_function(
File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_functorch/aot_autograd.py", line 570, in create_aot_dispatcher_function
return _create_aot_dispatcher_function(
File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_functorch/aot_autograd.py", line 820, in _create_aot_dispatcher_function
compiled_fn, fw_metadata = compiler_fn(
File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 205, in aot_dispatch_base
compiled_fw = compiler(fw_module, updated_flat_args)
File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_functorch/aot_autograd.py", line 479, in __call__
return self.compiler_fn(gm, example_inputs)
File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_inductor/compile_fx.py", line 2038, in fw_compiler_base
return inner_compile(
File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_inductor/compile_fx.py", line 623, in compile_fx_inner
return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_dynamo/repro/after_aot.py", line 104, in debug_wrapper
inner_compiled_fn = compiler_fn(gm, example_inputs)
File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_inductor/compile_fx.py", line 712, in _compile_fx_inner
mb_compiled_graph, cache_info = FxGraphCache.load_with_key(
File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_inductor/codecache.py", line 1287, in load_with_key
compiled_graph, cache_info = FxGraphCache._lookup_graph(
File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_inductor/codecache.py", line 1055, in _lookup_graph
artifact_path = graph.after_deserialization(constants)
File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_inductor/output_code.py", line 554, in after_deserialization
self.current_callable = PyCodeCache.load_by_key_path(
File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_inductor/codecache.py", line 2757, in load_by_key_path
mod = _reload_python_module(key, path)
File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_inductor/runtime/compile_tasks.py", line 51, in _reload_python_module
exec(code, mod.__dict__, mod.__dict__)
File "/workspace1/xingyuan/20250213-flexatt-enable/torchinductor_cache/vl/cvlvejmhrxqgkvn3ia4tojadysgyykk7h2kuqwfxpc6ltvkyoapo.py", line 50, in <module>
triton_poi_fused_ones_0 = async_compile.triton('triton_poi_fused_ones_0', '''
File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_inductor/async_compile.py", line 254, in triton
kernel.precompile()
File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_inductor/runtime/triton_heuristics.py", line 265, in precompile
self._precompile_worker()
File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_inductor/runtime/triton_heuristics.py", line 280, in _precompile_worker
compile_results.append(self._precompile_config(c))
File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_inductor/runtime/triton_heuristics.py", line 518, in _precompile_config
binary = triton.compile(*compile_args, **compile_kwargs)
File "/workspace1/xingyuan/20250213-flexatt-enable/intel-xpu-backend-for-triton/python/triton/compiler/compiler.py", line 285, in compile
next_module = compile_ir(module, metadata)
File "/workspace1/xingyuan/20250213-flexatt-enable/intel-xpu-backend-for-triton/python/triton/backends/intel/compiler.py", line 431, in <lambda>
stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, self.properties)
File "/workspace1/xingyuan/20250213-flexatt-enable/intel-xpu-backend-for-triton/python/triton/backends/intel/compiler.py", line 272, in make_ttgir
return XPUBackend.AdvancedPath.make_ttgir(mod, metadata, opt)
File "/workspace1/xingyuan/20250213-flexatt-enable/intel-xpu-backend-for-triton/python/triton/backends/intel/compiler.py", line 122, in make_ttgir
pm.run(mod)
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: PassManager::run failed
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
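As the message suggests, more detail on the failing kernel can be obtained with verbose logging. A minimal sketch of enabling roughly equivalent logs programmatically before running the reproducer (the particular set of flags, e.g. output_code to dump the generated Triton source, is an assumption; the environment-variable form from the message above works as well):

# Rough equivalent of TORCH_LOGS="+dynamo" / TORCHDYNAMO_VERBOSE=1 from the
# error message, set programmatically. The flag selection is an assumption,
# not from the original report.
import logging
import torch

torch._logging.set_logs(dynamo=logging.DEBUG, inductor=logging.DEBUG, output_code=True)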
Environment details
torch 2.7.0a0+gitbbc1fc4
triton 3.2.0+gitd27457fd