
[FlexAttention] Crash When Number of Variables Exceeds Assigned Values in Multi-Assignment #3420

Closed
@retonym

Description


Describe the bug

Issue Description:

While testing the PyTorch flex-attention module, we encountered a scenario where the number of variables exceeds the number of assigned values in a multi-assignment statement. For example:

a, b, c = 1, 1

Although this code is syntactically unusual, it executes successfully with public Triton on an NVIDIA GPU. However, when run with Intel Triton, it fails with the error message:

IndexError('list index out of range')

We would like to request support for handling such cases in Intel Triton. Thank you for your assistance.
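For context, plain CPython parses this statement but raises ValueError: not enough values to unpack (expected 3, got 2) when it runs; inside a @triton.jit kernel, the assignment is instead processed by Triton's Python AST frontend. The following sketch is hypothetical, not Intel Triton's actual code; it only illustrates how unpacking targets by position can produce this IndexError when targets outnumber values:

# Hypothetical, simplified frontend that binds unpacked targets by index.
# This illustrates the failure mode only; it is not Intel Triton's code.
targets = ['a', 'b', 'c']  # left-hand side of `a, b, c = 1, 1`
values = [1, 1]            # right-hand side
env = {}
for i, name in enumerate(targets):
    env[name] = values[i]  # IndexError: list index out of range at i == 2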

Steps to Reproduce:

Below is a minimal reproducible example that demonstrates the issue:

import triton
import triton.language as tl
import torch

device = 'xpu' if torch.xpu.is_available() else 'cuda'

@triton.jit
def test_varlen_list(A, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
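    # Intentional mismatch: three targets, only two values.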
    a, b, c = 1, 1
    tl.store(A + offsets, a)


BLOCK_SIZE = 32
TENSOR_SIZE = 512

A = torch.empty(TENSOR_SIZE, dtype=torch.float32, device=device)
grid = lambda meta: ((A.numel() + BLOCK_SIZE - 1) // BLOCK_SIZE,)

test_varlen_list[grid](A, BLOCK_SIZE)
print(A)
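If the kernel logic permits, a minimal workaround sketch (our suggestion, not an official fix) is to make the number of targets match the number of values inside the kernel:

a, b = 1, 1    # two targets, two values
# or, if three names are needed:
a = b = c = 1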

Environment details

PyTorch commit: 520079b986b76dc42e7fec3f992da3b1771e3192
Intel Triton commit: b7840ba
Public Triton commit: 4b3bb1f8
