Skip to content

Commit d82312d

Browse files
authored
Refactor rope (#199)
1 parent d8019aa commit d82312d

File tree

5 files changed: +8 additions, −37 deletions

dlinfer/graph/dicp/vendor/AtbGraph/conversion.py

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,8 @@ def npu_rms_norm_w8a8(self, x, w, eps=1e-6, quant_dtype=torch.int8):
133133
)
134134
return rms_norm_w8a8
135135

136-
@register_conversion("torch.ops.lmdeploy.apply_rotary_pos_emb.default")
137-
def apply_rotary_pos_emb(self, q, k, cos, sin, q_out, k_out):
136+
@register_conversion("torch.ops.dlinfer.apply_rotary_pos_emb.default")
137+
def apply_rotary_pos_emb(self, q, k, cos, sin):
138138
q_shape = list(q.node.meta["val"].shape)
139139
k_shape = list(k.node.meta["val"].shape)
140140
is_qk_require_reshape = len(q_shape) == 3
@@ -151,22 +151,6 @@ def apply_rotary_pos_emb(self, q, k, cos, sin, q_out, k_out):
151151
else self.get_proxy(atb_op.View, (k, (-1, k_shape[1] * k_shape[2])))
152152
)
153153
out = self.get_proxy(atb_op.Rope, (new_q, new_k, cos, sin, None))
154-
if is_qk_require_reshape:
155-
out_q = self.get_proxy(atb_op.GetItem, (out, 0))
156-
out_q = self.get_proxy(atb_op.View, (out_q, (-1, q_shape[1], q_shape[2])))
157-
out_k = self.get_proxy(atb_op.GetItem, (out, 1))
158-
out_k = self.get_proxy(atb_op.View, (out_k, (-1, k_shape[1], k_shape[2])))
159-
out = self.get_proxy(atb_op.Tuple, (out_q, out_k))
160-
if (q_out is not None) and (k_out is not None):
161-
self.get_proxy(
162-
atb_op.AclNnInplaceCopy,
163-
(q_out, self.get_proxy(atb_op.GetItem, (out, 0))),
164-
)
165-
self.get_proxy(
166-
atb_op.AclNnInplaceCopy,
167-
(k_out, self.get_proxy(atb_op.GetItem, (out, 1))),
168-
)
169-
out = self.get_proxy(atb_op.Tuple, (q_out, k_out))
170154
return out
171155

172156
@register_conversion("torch.ops.atb.inplace_div.default")

dlinfer/ops/llm.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,6 @@ def apply_rotary_pos_emb(
5959
key: Tensor,
6060
cos: Optional[Tensor],
6161
sin: Optional[Tensor],
62-
position_ids: Optional[Tensor],
63-
cos_sin_cache: Optional[Tensor],
6462
) -> Tuple[Tensor, Tensor]:
6563
"""
6664
Apply rotary position embeddings to the query and key tensors.
@@ -73,13 +71,6 @@ def apply_rotary_pos_emb(
7371
key (Tensor): The key tensor to apply the rotary position embeddings to.
7472
cos (Optional[Tensor]): The cosine component of the rotary position embeddings.
7573
sin (Optional[Tensor]): The sine component of the rotary position embeddings.
76-
position_ids (Optional[Tensor]): The position ids used to look up the rotary position embeddings.
77-
cos_sin_cache (Optional[Tensor]): A cache of pre-computed cosine and sine values.
78-
79-
Note:
80-
The parameter groups are mutually exclusive:
81-
- If `cos` and `sin` are both `None`, then `position_ids` and `cos_sin_cache` must both be Tensor.
82-
- If `position_ids` and `cos_sin_cache` are both `None`, then `cos` and `sin` must both be Tensor.
8374
8475
Returns:
8576
Tuple[Tensor, Tensor]:
@@ -91,8 +82,6 @@ def apply_rotary_pos_emb(
9182
key,
9283
cos,
9384
sin,
94-
position_ids,
95-
cos_sin_cache,
9685
)
9786

9887

dlinfer/vendor/ascend/torch_npu_ops.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,16 +41,14 @@ def apply_rotary_pos_emb(
4141
key: Tensor,
4242
cos: Optional[Tensor],
4343
sin: Optional[Tensor],
44-
position_ids: Optional[Tensor],
45-
cos_sin_cache: Optional[Tensor],
4644
) -> Tuple[Tensor, Tensor]:
4745
# rotary pos emb helpers:
46+
query = query.contiguous().unsqueeze(0)
47+
key = key.contiguous().unsqueeze(0)
4848
assert len(query.shape) == 4
4949
batch, seq_len, _, _ = query.shape
5050
cos = cos.reshape(batch, seq_len, 1, -1)
5151
sin = sin.reshape(batch, seq_len, 1, -1)
52-
query = query.contiguous()
53-
key = key.contiguous()
5452

5553
def rotate_half_(x):
5654
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]

dlinfer/vendor/camb/camb_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,9 @@ def apply_rotary_pos_emb(
8686
key: Tensor,
8787
cos: Optional[Tensor], # (total_seq_len, head_dim)
8888
sin: Optional[Tensor],
89-
position_ids: Optional[Tensor],
90-
cos_sin_cache: Optional[Tensor],
9189
) -> Tuple[Tensor, Tensor]:
90+
query = query.contiguous().unsqueeze(0)
91+
key = key.contiguous().unsqueeze(0)
9292
interleaved = False # False for fold rope, True for cross rope
9393
# [1, total_seq_len, q_head_num, head_dim]
9494
_, total_seq_len, _, head_dim = query.shape

dlinfer/vendor/maca/maca_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,9 @@ def apply_rotary_pos_emb(
7979
key: Tensor,
8080
cos: Optional[Tensor],
8181
sin: Optional[Tensor],
82-
position_ids: Optional[Tensor],
83-
cos_sin_cache: Optional[Tensor],
8482
) -> Tuple[Tensor, Tensor]:
83+
query = query.contiguous().unsqueeze(0)
84+
key = key.contiguous().unsqueeze(0)
8585
position_ids_1d = torch.arange(0, query.size(1), device=query.device)
8686
head_size = query.size(-1)
8787
query = query.flatten(-2, -1)

Comments (0)