Describe the bug
Issue Description:
While testing the PyTorch flex-attention module, we observed an accuracy failure in a unit test when running on the XPU device; the same test passes on the CUDA device.
Upon further investigation, we found that the same Triton kernel produces different output values on the XPU and CUDA devices even though the input tensors are identical.
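For reference, the mismatch is of the kind that a direct comparison of the saved outputs exposes. A minimal sketch (the file names out_xpu.pt and out_cuda.pt are hypothetical placeholders for the flex-attention outputs saved from the two runs, and the tolerances are illustrative rather than the unit test's own thresholds):

import torch

# Load the flex-attention outputs saved from the XPU and CUDA runs (hypothetical file names).
out_xpu = torch.load("out_xpu.pt", map_location="cpu")
out_cuda = torch.load("out_cuda.pt", map_location="cpu")

print("max abs diff:", (out_xpu.float() - out_cuda.float()).abs().max().item())
# Expected to flag the mismatch described above if the difference exceeds the tolerances.
torch.testing.assert_close(out_xpu, out_cuda, atol=2e-2, rtol=2e-2)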
Steps to Reproduce:
Due to GitHub's file size limit, please refer to the linked ticket TRITONXPU-175 for the reproducing scripts for the XPU and CUDA devices, as well as the input tensors.
Reproducing script for the XPU device:
# AOT ID: ['0_forward']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
import torch._inductor.kernel.flex_attention
from torch._C import _xpu_getCurrentRawStream as get_raw_stream
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import (
grid,
split_scan_grid,
grid_combo_kernels,
start_graph,
end_graph,
cooperative_reduction_grid,
)
from torch._C import _xpu_getCurrentRawStream as get_raw_stream
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
# kernel path: /home/yunfei/code/docker_folder/pytorch/torchinductor_cache/s4/cs4r72zltht2fd6b4xaq4vnmlkpx6hbtsl6ikpxz2sayfmxw6bv7.py
# Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
# Source node to ATen node mapping:
# flex_attention => flex_attention
# Graph fragment:
# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_1, %primals_2, %primals_3, %sdpa_score0, (1, 1, %primals_4, %primals_5, None, None, %primals_6, %primals_7, None, None, 1073741824, 1073741824, %sdpa_mask0), 0.125, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True}, (), ()), kwargs = {})
triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
@triton_heuristics.template(
num_stages=1,
num_warps=8,
triton_meta={'signature': {'arg_Q': '*fp16', 'arg_K': '*fp16', 'arg_V': '*fp16', 'arg_LSE': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*fp32', 'arg_FULL_KV_IDX': '*fp32', 'out_ptr0': '*fp16'}, 'device': DeviceProperties(type='xpu', index=0, multi_processor_count=56, cc={'architecture': 13136561920, 'driver_version': '1.6.31294+13', 'gpu_eu_count': 448, 'gpu_subslice_count': 56, 'has_atomic64': True, 'has_bfloat16_conversions': True, 'has_fp16': True, 'has_fp64': True, 'has_subgroup_2d_block_io': True, 'has_subgroup_matrix_multiply_accumulate': True, 'has_subgroup_matrix_multiply_accumulate_tensor_float32': False, 'max_compute_units': 448, 'max_num_sub_groups': 64, 'max_work_group_size': 1024, 'name': 'Intel(R) Data Center GPU Max 1100', 'platform_name': 'Intel(R) oneAPI Unified Runtime over Level-Zero', 'sub_group_sizes': [16, 32], 'total_memory': 51539607552, 'type': 'gpu', 'vendor': 'Intel(R) Corporation', 'version': '12.60.7'}, major=None, regs_per_multiprocessor=None, max_threads_per_multi_processor=None, warp_size=32), 'constants': {}, 'configs': [{(3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]},
inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'A3F77046C72CF77D846D2C5812F82358A8B1BF98607E9C275E4719740A017E73', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
)
@triton.jit
def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0):
PRESCALE_QK : tl.constexpr = False
ROWS_GUARANTEED_SAFE : tl.constexpr = False
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
WRITE_DQ : tl.constexpr = True
OUTPUT_LOGSUMEXP : tl.constexpr = True
FLOAT32_PRECISION : tl.constexpr = 'tf32'
IS_DIVISIBLE : tl.constexpr = True
SM_SCALE : tl.constexpr = 0.125
GQA_SHARED_HEADS : tl.constexpr = 1
HAS_FULL_BLOCKS : tl.constexpr = False
QK_HEAD_DIM : tl.constexpr = 64
QK_HEAD_DIM_ROUNDED : tl.constexpr = 64
V_HEAD_DIM : tl.constexpr = 64
V_HEAD_DIM_ROUNDED : tl.constexpr = 64
SAFE_HEAD_DIM : tl.constexpr = True
BLOCK_M : tl.constexpr = 128
BLOCK_N : tl.constexpr = 64
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 1073741824
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 1073741824
ALLOW_TF32 : tl.constexpr = False
Q = arg_Q
K = arg_K
V = arg_V
LSE = arg_LSE
KV_NUM_BLKS = arg_KV_NUM_BLKS
KV_IDX = arg_KV_IDX
FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
FULL_KV_IDX = arg_FULL_KV_IDX
# Sub notation for this kernel:
#
# Q: Query, K: Key, V: Value
# M: Number of queries, N: Number of keys/values, D: Model dimension
# QK_HEAD_DIM: The dimension of the query and key embeddings
# V_HEAD_DIM: The dimension of the value embeddings
# z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
#
# The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
# FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
# FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
#
# OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
#
# (Modifiable) Performance tuning options
# BLOCK_M: The thread block size across the seqlen dim of Q.
# BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
# The below are kernel options that can be applied for certain score_mods,
# or involve a numerics vs. perf tradeoff
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
# about 20% more numerical error, but slightly faster.
# ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
# is not masked out? If so, we can skip an extra safety check
# BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
# contiguous? If so, we don't need to do an indirect jump for every block
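# Concrete sizes for this repro (from the constants defined below):
#   Q_LEN = 1024, BLOCK_M = 128 -> program_id(0) ranges over 8 query blocks per (batch, head)
#   ZQ = 4, HQ = 8              -> program_id(1) ranges over the 32 (batch, head) pairs
#   KV_LEN = 2048, BLOCK_N = 64 -> each query block iterates over up to 32 KV blocks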
tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
# Define strides of inputs
stride_qz, stride_qh, stride_qm, stride_qk = 1048576, 131072, 64, 1
stride_kz, stride_kh, stride_kn, stride_kk = 1048576, 131072, 64, 1
stride_vz, stride_vh, stride_vn, stride_vk = 1048576, 131072, 64, 1
ZQ = 4
HQ = 8
Q_LEN = 1024
ZKV = 4
KV_LEN = 2048
MATMUL_PRECISION = Q.dtype.element_ty
q_start = tl.program_id(0)
off_zq = tl.program_id(1) // HQ
off_hq = tl.program_id(1) % HQ
# We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
# b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
off_zkv = off_zq % ZKV
off_hkv = off_hq // GQA_SHARED_HEADS
off_g = off_hq % GQA_SHARED_HEADS
q_offset = off_zq * stride_qz + off_hq * stride_qh
k_offset = off_zkv * stride_kz + off_hkv * stride_kh
v_offset = off_zkv * stride_vz + off_hkv * stride_vh
Q = Q + q_offset
K = K + k_offset
V = V + v_offset
SPARSE_Z = 1
SPARSE_HQ = 1
sparse_idx_z = off_zq % SPARSE_Z
sparse_idx_hq = off_hq % SPARSE_HQ
SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
stride_kv_num_blks_h = 1
stride_kv_idx_h = 1
stride_kv_idx_m = 1
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
# KV_IDX and KV_NUM_BLKS are always contiguous.
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
Q_block_ptr = tl.make_block_ptr(
base=Q,
shape=(Q_LEN, QK_HEAD_DIM),
strides=(stride_qm, stride_qk),
offsets=(q_start * BLOCK_M, 0),
block_shape=(BLOCK_M, QK_HEAD_DIM_ROUNDED),
order=(1, 0)
)
q = load_checked_block(Q_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM)
# ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# We don't know anything "special" about these blocks, so we need to apply
# both score_mod and mask_mod to it
kv_indices = KV_IDX + sparse_kv_idx_offset
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
K_block_ptr = tl.make_block_ptr(
base=K,
shape=(QK_HEAD_DIM, KV_LEN),
strides=(stride_kk, stride_kn),
offsets=(0, kv_start),
block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N),
order=(0, 1)
)
V_block_ptr = tl.make_block_ptr(
base=V,
shape=(KV_LEN, V_HEAD_DIM),
strides=(stride_vn, stride_vk),
offsets=(kv_start, 0),
block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED),
order=(1, 0)
)
offs_n = kv_start + tl.arange(0, BLOCK_N)
acc, l_i, m_i = forward_inner(
arg_Q, arg_K, arg_V, arg_LSE, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0,
q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
acc, l_i, m_i,
off_zq, off_hq, offs_m[:, None], offs_n[None, :],
kv_indices, kv_num_blocks,
0, block_n_end,
MATMUL_PRECISION,
IS_FULL_BLOCKS=False,
)
# ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# We know these blocks are guaranteed to be "full", so we don't need to
# apply mask_mod to them - only score_mod
if HAS_FULL_BLOCKS:
# FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
K_block_ptr = tl.make_block_ptr(
base=K,
shape=(QK_HEAD_DIM, KV_LEN),
strides=(stride_kk, stride_kn),
offsets=(0, kv_start),
block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N),
order=(0, 1)
)
V_block_ptr = tl.make_block_ptr(
base=V,
shape=(KV_LEN, V_HEAD_DIM),
strides=(stride_vn, stride_vk),
offsets=(kv_start, 0),
block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED),
order=(1, 0)
)
offs_n = kv_start + tl.arange(0, BLOCK_N)
acc, l_i, m_i = forward_inner(
arg_Q, arg_K, arg_V, arg_LSE, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0,
q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
acc, l_i, m_i,
off_zq, off_hq, offs_m[:, None], offs_n[None, :],
kv_indices, kv_num_blocks,
0, block_n_end,
MATMUL_PRECISION,
IS_FULL_BLOCKS=True,
)
# [Note] Handle fully masked out rows:
# Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
# We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
l_i = tl.where(l_i == 0.0, 1, l_i)
acc = acc / l_i[:, None]
idx_zq = tl.program_id(1) // HQ
idx_hq = tl.program_id(1) % HQ
idx_m = offs_m[:, None]
idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :]
mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
xindex = idx_d + 64*idx_m + 65536*idx_hq + 524288*idx_zq
tl.store(out_ptr0 + (tl.broadcast_to(xindex, acc.shape)), acc, mask)
if OUTPUT_LOGSUMEXP:
off_hz = tl.program_id(1)
l_ptrs = LSE + off_hz * Q_LEN + offs_m
lse = m_i + tl.math.log2(l_i)
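# m_i and l_i were accumulated with exp2 on scores pre-scaled by RCP_LN2, so this logsumexp is in base-2 units.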
if IS_DIVISIBLE:
tl.store(l_ptrs, lse)
else:
tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
@triton.jit
def forward_inner(
arg_Q, arg_K, arg_V, arg_LSE, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0,
q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
# accumulated values
acc, l_i, m_i,
# Offsets used as inputs to score_mod & mask_mod
# of size [BLOCK_M, BLOCK_N] or scalar.
off_z, off_h, offs_m, offs_n,
# blocksparse data
kv_indices, kv_num_blocks,
# start kv and end kv block
block_n_start, block_n_end,
MATMUL_PRECISION,
IS_FULL_BLOCKS,
):
# Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
PRESCALE_QK : tl.constexpr = False
ROWS_GUARANTEED_SAFE : tl.constexpr = False
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
WRITE_DQ : tl.constexpr = True
OUTPUT_LOGSUMEXP : tl.constexpr = True
FLOAT32_PRECISION : tl.constexpr = 'tf32'
IS_DIVISIBLE : tl.constexpr = True
SM_SCALE : tl.constexpr = 0.125
GQA_SHARED_HEADS : tl.constexpr = 1
HAS_FULL_BLOCKS : tl.constexpr = False
QK_HEAD_DIM : tl.constexpr = 64
QK_HEAD_DIM_ROUNDED : tl.constexpr = 64
V_HEAD_DIM : tl.constexpr = 64
V_HEAD_DIM_ROUNDED : tl.constexpr = 64
SAFE_HEAD_DIM : tl.constexpr = True
BLOCK_M : tl.constexpr = 128
BLOCK_N : tl.constexpr = 64
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 1073741824
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 1073741824
ALLOW_TF32 : tl.constexpr = False
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
RCP_LN2: tl.constexpr = 1.44269504
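# RCP_LN2 = 1/ln(2): scores are scaled by this constant so that exp2/log2 can replace exp/log in the softmax below.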
if PRESCALE_QK:
q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
# loop over k, v and update accumulator until block_n_end
for start_n in range(block_n_start, block_n_end):
if IS_DIVISIBLE:
acc, l_i, m_i = forward_block_mn(
arg_Q, arg_K, arg_V, arg_LSE, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0,
q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
# accumulated values
acc, l_i, m_i,
# Offsets
off_z, off_h, offs_m, offs_n,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS,
)
else:
# Benchmark shows even we applied mod & mask to each block for non divisible seqlen,
# it's on par or slightly faster than only applying to the last block in fwd.
# However, we choose different strategy for bwd, where we only apply mod & mask
# to the last block because it's faster a lot.
acc, l_i, m_i = forward_block_mn(
arg_Q, arg_K, arg_V, arg_LSE, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0,
q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
# accumulated values
acc, l_i, m_i,
# Offsets
off_z, off_h, offs_m, offs_n,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
)
# update pointers
offset = get_offset_for_next_block(
start_n, kv_indices, kv_num_blocks,
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
)
V_block_ptr = tl.advance(V_block_ptr, (offset, 0))
K_block_ptr = tl.advance(K_block_ptr, (0, offset))
offs_n = offs_n + offset
return acc, l_i, m_i
@triton.jit
def get_offset_for_next_block(
loop_iter, col_indices, total_blocks,
SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
BLOCKS_ARE_CONTIGUOUS: tl.constexpr
):
if BLOCKS_ARE_CONTIGUOUS:
return BLOCK
cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
return offset
@triton.jit
def forward_block_mn(
arg_Q, arg_K, arg_V, arg_LSE, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0,
q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
# accumulated values
acc, l_i, m_i,
# Offsets
off_z, off_h, offs_m, offs_n,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
):
# Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
PRESCALE_QK : tl.constexpr = False
ROWS_GUARANTEED_SAFE : tl.constexpr = False
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
WRITE_DQ : tl.constexpr = True
OUTPUT_LOGSUMEXP : tl.constexpr = True
FLOAT32_PRECISION : tl.constexpr = 'tf32'
IS_DIVISIBLE : tl.constexpr = True
SM_SCALE : tl.constexpr = 0.125
GQA_SHARED_HEADS : tl.constexpr = 1
HAS_FULL_BLOCKS : tl.constexpr = False
QK_HEAD_DIM : tl.constexpr = 64
QK_HEAD_DIM_ROUNDED : tl.constexpr = 64
V_HEAD_DIM : tl.constexpr = 64
V_HEAD_DIM_ROUNDED : tl.constexpr = 64
SAFE_HEAD_DIM : tl.constexpr = True
BLOCK_M : tl.constexpr = 128
BLOCK_N : tl.constexpr = 64
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 1073741824
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 1073741824
ALLOW_TF32 : tl.constexpr = False
# -- load k --
k = load_checked_block(K_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM)
# -- compute qk ---
qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
if not PRESCALE_QK:
qk *= SM_SCALE
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
# If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
# which is larger than the actual number of elements. To avoid access memory out of bound,
# we need to mask out the elements that are out of Q_LEN & KV_LEN.
m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
tmp0 = (qk)
tmp1 = tmp0.to(tl.float32)
tmp2 = (n)
tmp3 = (m)
tmp4 = tmp2 - tmp3
tmp5 = tmp4.to(tl.float32)
tmp6 = (off_h)
tmp7 = tl.full([1], 1, tl.int32)
tmp8 = tmp6 + tmp7
tmp9 = tmp8.to(tl.float32)
tmp10 = 8.0
tmp11 = tmp9 * tmp10
tmp12 = 0.125
tmp13 = tmp11 * tmp12
tmp14 = -tmp13
tmp15 = libdevice.exp2(tmp14)
tmp16 = tmp5 * tmp15
tmp17 = tmp1 + tmp16
post_mod_scores = tmp17
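# The inlined score_mod above adds an ALiBi-style bias: qk + (kv_idx - q_idx) * 2**(-(off_h + 1)), since 8.0 * 0.125 == 1.0.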
if CHECK_BLOCK_BOUNDARY:
# Mask out the elements that are out of the KV_LEN for non divisible seqlen.
post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
if not IS_FULL_BLOCKS:
tmp18 = tl.full([1], True, tl.int1)
mask_mod_output = tmp18
if CHECK_BLOCK_BOUNDARY:
mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
# apply mask for partially unmasked blocks
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
if not PRESCALE_QK:
post_mod_scores *= RCP_LN2
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# -- compute scaling constant ---
m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
if not ROWS_GUARANTEED_SAFE:
masked_out_rows = (m_ij == float("-inf"))
m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
else:
m_ij_masked = m_ij
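# Online-softmax rescaling: alpha shrinks the running l_i and acc whenever the row max m_ij increases; exp2 matches the RCP_LN2 scaling of the scores.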
alpha = tl.math.exp2(m_i - m_ij_masked)
p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
# NB: l_i update is pulled up here since it's a bit faster
# NB: For headdim=256, it's faster to move it back down to after m_i =
# m_ij
l_i = l_i * alpha + tl.sum(p, 1)
# # -- scale and update acc --
acc = acc * alpha[:, None]
v = load_checked_block(V_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM)
acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
# -- update m_i
m_i = m_ij
return acc, l_i, m_i
@triton.jit
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
if IS_DIVISIBLE and SAFE_HEAD_DIM:
return tl.load(block_ptr)
elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
else:
return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
@triton.jit
def get_bounded_indices(indices, max_len=None):
return indices % max_len if max_len is not None else indices
''', device_str='xpu')
meta0 = {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.125, 'GQA_SHARED_HEADS': 1, 'HAS_FULL_BLOCKS': False, 'QK_HEAD_DIM': 64, 'QK_HEAD_DIM_ROUNDED': 64, 'V_HEAD_DIM': 64, 'V_HEAD_DIM_ROUNDED': 64, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 1073741824, 'SPARSE_KV_BLOCK_SIZE': 1073741824, 'ALLOW_TF32': 'False'}
async_compile.wait(globals())
del async_compile
def call(args):
primals_1, primals_2, primals_3, buf0, primals_4, primals_5, buf1, buf2, buf3 = args
with torch.xpu._DeviceGuard(0):
torch.xpu.set_device(0)
stream0 = get_raw_stream(0)
print("primals_1.shape:", primals_1.shape)
print("primals_1.stride:", primals_1.stride())
print("primals_2.shape:", primals_2.shape)
print("primals_2.stride:", primals_2.stride())
print("primals_3.shape:", primals_3.shape)
print("primals_3.stride:", primals_3.stride())
triton_tem_fused_0.run(primals_1, primals_2, primals_3, buf0, primals_4, primals_5, buf1, buf2, buf3, grid=torch._inductor.kernel.flex_attention.flex_attention_grid(4, 8, 1024, 64, meta0), stream=stream0)
print("buf3:", buf3)
return (buf3,)
if __name__ == "__main__":
file_path = "repro_tensors.pt"
loaded_tensor_tuple = torch.load(file_path)
primals_1, primals_2, primals_3, buf0, primals_4, primals_5, buf1, buf2, buf3 = loaded_tensor_tuple
args = loaded_tensor_tuple
call(args)
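To quantify the difference, the buf3 produced by this XPU run can be compared against the corresponding buffer from the CUDA script in TRITONXPU-175. A minimal sketch, assuming each run saves its output to disk (buf3_xpu.pt and buf3_cuda.pt are hypothetical file names):

import torch

buf3_xpu = torch.load("buf3_xpu.pt", map_location="cpu")    # hypothetical: buf3 saved from the XPU run above
buf3_cuda = torch.load("buf3_cuda.pt", map_location="cpu")   # hypothetical: buf3 saved from the CUDA run
diff = (buf3_xpu.float() - buf3_cuda.float()).abs()
print("max abs diff:", diff.max().item())
print("mean abs diff:", diff.mean().item())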
Environment details
PyTorch commit: 520079b986b76dc42e7fec3f992da3b1771e3192
Intel Triton commit: b7840ba
Public Triton commit: 4b3bb1f8