
Commit e51c76e

fix attention
1 parent 4bd27e5 commit e51c76e

3 files changed (+7, −10 lines)


dlinfer/graph/dicp/vendor/AtbGraph/atb_op.py

Lines changed: 1 addition & 1 deletion
@@ -295,9 +295,9 @@ def infer_result(
         mask,
         q_head_num,
         kv_head_num,
+        scale,
         head_size,
         head_size_v,
-        scale,
     ):
         return query.new_empty((query.shape[0], q_head_num, head_size_v))
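
Why the reorder matters: these arguments are passed positionally from conversion.py, so the stub above, the codegen builder below, and the call sites must all agree on parameter order. A minimal sketch of the hazard (illustrative names and values, not the dlinfer API):

    # Minimal sketch of the positional-argument hazard this commit fixes.
    # Names and values are illustrative, not the dlinfer API.

    def infer_result_old(q_head_num, kv_head_num, head_size, head_size_v, scale):
        return {"head_size": head_size, "scale": scale}

    def infer_result_new(q_head_num, kv_head_num, scale, head_size, head_size_v):
        return {"head_size": head_size, "scale": scale}

    # A caller built against the new order: (..., scale, head_size, head_size_v)
    args = (32, 8, 0.0884, 128, 128)

    print(infer_result_old(*args))  # {'head_size': 0.0884, 'scale': 128} -- swapped
    print(infer_result_new(*args))  # {'head_size': 128, 'scale': 0.0884} -- correct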

dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb_op.py

Lines changed: 1 addition & 1 deletion
@@ -355,9 +355,9 @@ def SelfAttentionPAEncoder(
         mask,
         q_head_num,
         kv_head_num,
+        scale,
         head_size,
         head_size_v,
-        scale,
     ):
         op = Operation(name, "SelfAttentionOperation")
         param = infer_param.SelfAttentionParam()
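
The same reorder has to land here and in the infer_result stub above, since both consume the same positional tuple built in conversion.py. One way to guard against the two signatures drifting apart again, sketched with stand-in definitions (the check is a suggestion, not something in the repo):

    import inspect

    def tail_param_names(fn, n):
        """Names of the last n parameters of fn."""
        return list(inspect.signature(fn).parameters)[-n:]

    # Illustrative stand-ins for the op stub and the codegen builder.
    def infer_result(query, key, value, mask, q_head_num, kv_head_num,
                     scale, head_size, head_size_v):
        pass

    def SelfAttentionPAEncoder(name, query, key, value, mask, q_head_num,
                               kv_head_num, scale, head_size, head_size_v):
        pass

    # conversion.py passes these positionally, so the tails must agree.
    assert tail_param_names(infer_result, 5) == tail_param_names(SelfAttentionPAEncoder, 5)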

dlinfer/graph/dicp/vendor/AtbGraph/conversion.py

Lines changed: 5 additions & 8 deletions
@@ -478,12 +478,11 @@ def prefill_attention(
             if softmax_scale
             else 1.0 / math.sqrt(query.node.meta["val"].shape[-1])
         )
+        _, num_q_heads, head_size = query.node.meta["val"].shape
+        _, num_kv_heads, head_size_v = value.node.meta["val"].shape
         if query.node.meta["val"].dtype != mask.node.meta["val"].dtype:
             mask = self.get_proxy(atb_op.Cast, (mask, query.node.meta["val"].dtype))
         if is_unpaged_prefill:
-            _, num_q_heads, head_size = query.node.meta["val"].shape
-            _, num_kv_heads, head_size_v = value.node.meta["val"].shape
-
             out = self.get_proxy(
                 atb_op.SelfAttentionPAEncoder,
                 (
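
The hoisted reads assume the unpadded prefill layout [num_tokens, num_heads, head_dim] for query and value; lifting them above the branch lets the paged (else) path in the next hunk reuse head_size and head_size_v instead of re-deriving head counts from q_shape/k_shape. A synthetic example of what the unpacking yields (shapes are made up; requires PyTorch):

    import torch

    # Unpadded layout assumed by the hoisted reads: [num_tokens, num_heads, head_dim]
    query = torch.empty(64, 32, 128)  # 32 query heads
    value = torch.empty(64, 8, 128)   # 8 KV heads (grouped-query attention)

    _, num_q_heads, head_size = query.shape
    _, num_kv_heads, head_size_v = value.shape

    print(num_q_heads, head_size, num_kv_heads, head_size_v)  # 32 128 8 128
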
@@ -494,18 +493,14 @@ def prefill_attention(
                     mask,
                     num_q_heads,
                     num_kv_heads,
+                    scale,
                     head_size,
                     head_size_v,
-                    scale,
                 ),
             )
         else:
-            q_shape = list(query.node.meta["val"].shape)
             k_cache_shape = list(k_cache.node.meta["val"].shape)
-            k_shape = list(key.node.meta["val"].shape)
             v_cache_shape = list(v_cache.node.meta["val"].shape)
-            num_q_heads = q_shape[-2]
-            num_kv_heads = k_shape[-2]
 
             is_kv_require_reshape = len(k_cache_shape) == 3 or len(v_cache_shape) == 3
             if is_kv_require_reshape:
@@ -529,6 +524,8 @@ def prefill_attention(
                     num_q_heads,
                     num_kv_heads,
                     scale,
+                    head_size,
+                    head_size_v,
                 ),
             )
         # graph = self.get_proxy(atb_op.Graph, (out,), {"output": [out]})
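
For reference, the scale now threaded through both call paths defaults (per the ternary at the top of the first hunk) to the standard attention scaling 1 / sqrt(head_dim) when no softmax_scale is given:

    import math

    head_size = 128                     # example head dimension
    scale = 1.0 / math.sqrt(head_size)  # default used when softmax_scale is None
    print(round(scale, 4))              # 0.0884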
