[FlexAttention] Triton XPU block IO returns incorrect values when the base address is not strictly aligned #3704

@chengjunlu

Describe the bug

In the FlexDecoding test case, we found that block IO returns incorrect matrix values when the base address is not aligned.

Inductor generates code like this:

    K_block_ptr = tl.make_block_ptr(
        base=K + k_offset,
        shape=(QK_HEAD_DIM, KV_LEN),                # (d, N)
        strides=(stride_kk, stride_kn),
        offsets=(0, off_n),
        block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N),
        order=(0, 1)
    )

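It adds `k_offset` directly to the base pointer instead of expressing it through `offsets`, so the resulting base address is not necessarily aligned.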
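To illustrate why folding the offset into the base matters, here is a minimal sketch, with hypothetical numbers, showing that when `K` itself is well aligned, the alignment of `K + k_offset` is governed entirely by `k_offset`. The 2-byte element size (fp16/bf16), the allocation address, and the sample offsets are assumptions, not values taken from the failing test.

```python
def alignment(addr: int, cap: int = 4096) -> int:
    """Largest power of two (up to `cap`) that divides `addr`."""
    a = 1
    while a < cap and addr % (a * 2) == 0:
        a *= 2
    return a

ELEM_SIZE = 2            # assumption: fp16/bf16 elements
K = 0xFF00000000200000   # hypothetical, well-aligned allocation

# Triton pointer arithmetic scales by the element size, so `K + k_offset`
# advances the byte address by k_offset * ELEM_SIZE.
for k_offset in (0, 64, 32, 5):  # hypothetical element offsets
    base = K + k_offset * ELEM_SIZE
    print(f"k_offset={k_offset:>3} -> base={base:#x}, aligned to {alignment(base)} bytes")
```

An offset of 64 elements keeps the base 128-byte aligned, 32 elements leaves it 64-byte aligned, and an offset of 5 elements leaves it only 2-byte aligned.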

The block pointer in the failing case (each dimension is printed on its own line):

    K_block_ptr base: 0xff000000002007ca
    K_block_ptr shape:  [64]
    K_block_ptr shape:  [2048]
    K_block_ptr strides:  [1]
    K_block_ptr strides:  [64]
    K_block_ptr offsets:  [0]
    K_block_ptr offsets:  [0]
    K_block_ptr block_shape:  [64]
    K_block_ptr block_shape:  [64]
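Decoding the dump, the base `0xff000000002007ca` is 10 bytes past a 64-byte boundary, so it is only 2-byte aligned. Intel 2D block IO is understood to require a 64-byte-aligned base address (treated as an assumption here), so one conceivable workaround is to round the base down and move the removed bytes into the offset of the contiguous dimension (dim 0 here, since `strides=(1, 64)` and `order=(0, 1)`). The sketch below assumes a 64-byte requirement and 2-byte elements; `BlockDesc` and `rebase_for_block_io` are hypothetical helpers, not existing Triton or compiler APIs.

```python
from dataclasses import dataclass

@dataclass
class BlockDesc:
    base: int        # byte address
    shape: tuple     # (d, N) in elements
    strides: tuple   # in elements; dim 0 is contiguous (order=(0, 1))
    offsets: tuple   # in elements

def rebase_for_block_io(desc: BlockDesc, elem_size: int = 2, align: int = 64) -> BlockDesc:
    """Hypothetical: round the base down to `align` bytes and grow the
    contiguous-dimension offset and shape by the removed elements."""
    misalign = desc.base % align
    assert misalign % elem_size == 0, "base is not even element-aligned"
    shift = misalign // elem_size
    return BlockDesc(
        base=desc.base - misalign,
        shape=(desc.shape[0] + shift, desc.shape[1]),
        strides=desc.strides,
        offsets=(desc.offsets[0] + shift, desc.offsets[1]),
    )

def elem_addr(desc: BlockDesc, i: int, j: int, elem_size: int = 2) -> int:
    """Byte address of element (i, j) of the block, relative to its offsets."""
    return desc.base + ((desc.offsets[0] + i) * desc.strides[0]
                        + (desc.offsets[1] + j) * desc.strides[1]) * elem_size

# Values from the dump above.
orig = BlockDesc(base=0xFF000000002007CA, shape=(64, 2048), strides=(1, 64), offsets=(0, 0))
fixed = rebase_for_block_io(orig)
print(hex(fixed.base), fixed.offsets, fixed.shape)  # 0xff000000002007c0 (5, 0) (69, 2048)
```

Every element still maps to the same byte address, and the bounds shift together with the offset, so no extra elements become in-bounds. Whether the hardware also imposes constraints on the row pitch (128 bytes here) or on the surface width is not covered by this sketch.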

Environment details

Triton XPU: Latest
