
TRITON_INTEL_ADVANCED_PATH does not support flex_attention #3486

Closed as not planned

Description

@hoshibara

Describe the bug

Issue Description

While compiling Llama 3.1 with the advanced path, we found that it throws a PassManager error.
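
The advanced path is selected solely through the TRITON_INTEL_ADVANCED_PATH environment variable, so a quick way to check whether the failure is specific to that path (assuming the default XPU path is unaffected, which is not verified here) is to run the same script with and without the flag:

TRITON_INTEL_ADVANCED_PATH=1 python reproducer.py   # fails with PassManager::run failed
python reproducer.py                                # flag unset: default XPU path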

Reproducer

TRITON_INTEL_ADVANCED_PATH=1 python reproducer.py
import unittest
import torch

from torch.nn.attention.flex_attention import flex_attention


class TestFlexAttentionCompile(unittest.TestCase):
    def setUp(self):
        self.batch_size = 1
        self.n_heads = 32
        self.seq_length = 16
        self.head_dim = 32
        self.device = torch.device("xpu")

        self.query = torch.randn(
            self.batch_size,
            self.n_heads,
            self.seq_length,
            self.head_dim,
            device=self.device,
            dtype=torch.float32,
        )
        self.key = torch.randn(
            self.batch_size,
            self.n_heads,
            self.seq_length,
            self.head_dim,
            device=self.device,
            dtype=torch.float32,
        )
        self.value = torch.randn(
            self.batch_size,
            self.n_heads,
            self.seq_length,
            self.head_dim,
            device=self.device,
            dtype=torch.float32,
        )

    def flex_attention_run(self, query, key, value):
        # score_mod callback: keep the score where q_idx >= kv_idx, otherwise
        # mask it to -inf (causal attention).
        def causal_mask(score, b, h, q_idx, kv_idx):
            return torch.where(q_idx >= kv_idx, score, -float("inf"))

        return flex_attention(
            query,
            key,
            value,
            score_mod=causal_mask,
            return_lse=True,
        )

    def test_compile_flex_attention(self):
        # Compile with the default Inductor backend; with
        # TRITON_INTEL_ADVANCED_PATH=1 the Intel XPU Triton backend takes the
        # advanced TTGIR path (XPUBackend.AdvancedPath.make_ttgir), where the
        # PassManager failure below occurs.
        compiled_flex_attn = torch.compile(
            self.flex_attention_run,
        )

        out_uncompiled = compiled_flex_attn(
            self.query,
            self.key,
            self.value,
        )
        print(out_uncompiled)


if __name__ == "__main__":
    unittest.main()
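
For comparison (not part of the original report), below is a minimal eager-mode sketch of the same call. It assumes flex_attention runs on XPU without torch.compile, which would narrow the failure to the compiled advanced-path route rather than the op itself:

# Hypothetical eager baseline; assumes flex_attention's eager fallback works
# on XPU (not verified in this report).
import torch
from torch.nn.attention.flex_attention import flex_attention

q = torch.randn(1, 32, 16, 32, device="xpu", dtype=torch.float32)
k = torch.randn(1, 32, 16, 32, device="xpu", dtype=torch.float32)
v = torch.randn(1, 32, 16, 32, device="xpu", dtype=torch.float32)

def causal_mask(score, b, h, q_idx, kv_idx):
    # Keep the score where the query position can attend to the key position.
    return torch.where(q_idx >= kv_idx, score, -float("inf"))

out, lse = flex_attention(q, k, v, score_mod=causal_mask, return_lse=True)
print(out.shape, lse.shape)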

Error msg

======================================================================
ERROR: test_compile_flex_attention (__main__.TestFlexAttentionCompile)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/workspace1/xingyuan/20250213-flexatt-enable/reproduce-adv-path-compile.py", line 60, in test_compile_flex_attention
    out_uncompiled = compiled_flex_attn(
  File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_dynamo/eval_frame.py", line 574, in _fn
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
  File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_dynamo/output_graph.py", line 1487, in _call_user_compiler
    raise BackendCompilerFailed(
  File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_dynamo/output_graph.py", line 1466, in _call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_dynamo/repro/after_dynamo.py", line 131, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/__init__.py", line 2339, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_inductor/compile_fx.py", line 2163, in compile_fx
    return aot_autograd(
  File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_dynamo/backends/common.py", line 83, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
  File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_functorch/aot_autograd.py", line 1158, in aot_module_simplified
    compiled_fn = AOTAutogradCache.load(
  File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_functorch/_aot_autograd/autograd_cache.py", line 779, in load
    compiled_fn = dispatch_and_compile()
  File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_functorch/aot_autograd.py", line 1143, in dispatch_and_compile
    compiled_fn, _ = create_aot_dispatcher_function(
  File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_functorch/aot_autograd.py", line 570, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
  File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_functorch/aot_autograd.py", line 820, in _create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
  File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 205, in aot_dispatch_base
    compiled_fw = compiler(fw_module, updated_flat_args)
  File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_functorch/aot_autograd.py", line 479, in __call__
    return self.compiler_fn(gm, example_inputs)
  File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_inductor/compile_fx.py", line 2038, in fw_compiler_base
    return inner_compile(
  File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_inductor/compile_fx.py", line 623, in compile_fx_inner
    return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
  File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_dynamo/repro/after_aot.py", line 104, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
  File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_inductor/compile_fx.py", line 712, in _compile_fx_inner
    mb_compiled_graph, cache_info = FxGraphCache.load_with_key(
  File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_inductor/codecache.py", line 1287, in load_with_key
    compiled_graph, cache_info = FxGraphCache._lookup_graph(
  File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_inductor/codecache.py", line 1055, in _lookup_graph
    artifact_path = graph.after_deserialization(constants)
  File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_inductor/output_code.py", line 554, in after_deserialization
    self.current_callable = PyCodeCache.load_by_key_path(
  File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_inductor/codecache.py", line 2757, in load_by_key_path
    mod = _reload_python_module(key, path)
  File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_inductor/runtime/compile_tasks.py", line 51, in _reload_python_module
    exec(code, mod.__dict__, mod.__dict__)
  File "/workspace1/xingyuan/20250213-flexatt-enable/torchinductor_cache/vl/cvlvejmhrxqgkvn3ia4tojadysgyykk7h2kuqwfxpc6ltvkyoapo.py", line 50, in <module>
    triton_poi_fused_ones_0 = async_compile.triton('triton_poi_fused_ones_0', '''
  File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_inductor/async_compile.py", line 254, in triton
    kernel.precompile()
  File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_inductor/runtime/triton_heuristics.py", line 265, in precompile
    self._precompile_worker()
  File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_inductor/runtime/triton_heuristics.py", line 280, in _precompile_worker
    compile_results.append(self._precompile_config(c))
  File "/workspace1/xingyuan/20250213-flexatt-enable/hoshibara-pytorch/torch/_inductor/runtime/triton_heuristics.py", line 518, in _precompile_config
    binary = triton.compile(*compile_args, **compile_kwargs)
  File "/workspace1/xingyuan/20250213-flexatt-enable/intel-xpu-backend-for-triton/python/triton/compiler/compiler.py", line 285, in compile
    next_module = compile_ir(module, metadata)
  File "/workspace1/xingyuan/20250213-flexatt-enable/intel-xpu-backend-for-triton/python/triton/backends/intel/compiler.py", line 431, in <lambda>
    stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, self.properties)
  File "/workspace1/xingyuan/20250213-flexatt-enable/intel-xpu-backend-for-triton/python/triton/backends/intel/compiler.py", line 272, in make_ttgir
    return XPUBackend.AdvancedPath.make_ttgir(mod, metadata, opt)
  File "/workspace1/xingyuan/20250213-flexatt-enable/intel-xpu-backend-for-triton/python/triton/backends/intel/compiler.py", line 122, in make_ttgir
    pm.run(mod)
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: PassManager::run failed

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
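
As the message suggests, more of the underlying backend error (and a fuller stack into make_ttgir) may be surfaced by rerunning with verbose logging:

TORCH_LOGS="+dynamo" TORCHDYNAMO_VERBOSE=1 TRITON_INTEL_ADVANCED_PATH=1 python reproducer.py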

Environment details

torch                   2.7.0a0+gitbbc1fc4
triton                  3.2.0+gitd27457fd
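
For completeness, a full environment report (not included in the original report) can be generated with PyTorch's standard script:

python -m torch.utils.collect_env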
