@@ -189,20 +189,23 @@ def forward(
189
189
190
190
# Assemble into a number of query groups to support MHA, MQA and GQA.
191
191
q_per_kv = self .config .num_heads // self .config .num_query_groups
192
- total_qkv = q_per_kv + 2 # Each group has >=1 queries, 1 key, and 1 value.
192
+ # Each query group holds q_per_kv (>= 1) query heads, 1 key head, and 1 value head.
193
193
if self .config .qkv_transpose_before_split :
194
- qkv = qkv .view (
195
- B , T , total_qkv , self .config .num_query_groups , self .head_dim
196
- ) # (B, T, total_qkv, num_query_groups, head_dim)
197
- qkv_axis = - 3
194
+ qkv = qkv .view (B , T , - 1 , self .head_dim )
195
+ q , k , v = qkv .split (
196
+ (
197
+ q_per_kv * self .config .num_query_groups ,
198
+ self .config .num_query_groups ,
199
+ self .config .num_query_groups ,
200
+ ),
201
+ dim = - 2 ,
202
+ )
198
203
else :
199
- qkv = qkv .view (
200
- B , T , self . config . num_query_groups , total_qkv , self . head_dim
201
- ) # (B, T, num_query_groups, total_qkv, head_dim)
202
- qkv_axis = - 2
204
+ qkv = qkv .view (B , T , self . config . num_query_groups , - 1 )
205
+ q , k , v = qkv . split (
206
+ ( q_per_kv * self . head_dim , self . head_dim , self . head_dim ), dim = - 1
207
+ )
203
208
204
- # Split batched computation into three.
205
- q , k , v = qkv .split ((q_per_kv , 1 , 1 ), dim = qkv_axis )
206
209
q = q .reshape (B , T , - 1 , self .head_dim )
207
210
k = k .reshape (B , T , - 1 , self .head_dim )
208
211
v = v .reshape (B , T , - 1 , self .head_dim )
0 commit comments