File tree Expand file tree Collapse file tree 1 file changed +8
-4
lines changed
ai_edge_torch/generative/utilities Expand file tree Collapse file tree 1 file changed +8
-4
lines changed Original file line number Diff line number Diff line change @@ -228,24 +228,28 @@ def _map_attention(
228
228
q_name = self ._names .attn_query_proj .format (idx )
229
229
k_name = self ._names .attn_key_proj .format (idx )
230
230
v_name = self ._names .attn_value_proj .format (idx )
231
- converted_state [f"{ prefix } .atten_func.attn .weight" ] = self ._fuse_qkv (
231
+ converted_state [f"{ prefix } .atten_func.qkv_projection .weight" ] = self ._fuse_qkv (
232
232
config ,
233
233
state .pop (f"{ q_name } .weight" ),
234
234
state .pop (f"{ k_name } .weight" ),
235
235
state .pop (f"{ v_name } .weight" ),
236
236
)
237
237
if config .attn_config .qkv_use_bias :
238
- converted_state [f"{ prefix } .atten_func.attn .bias" ] = self ._fuse_qkv (
238
+ converted_state [f"{ prefix } .atten_func.qkv_projection .bias" ] = self ._fuse_qkv (
239
239
config ,
240
240
state .pop (f"{ q_name } .bias" ),
241
241
state .pop (f"{ k_name } .bias" ),
242
242
state .pop (f"{ v_name } .bias" ),
243
243
)
244
244
245
245
o_name = self ._names .attn_output_proj .format (idx )
246
- converted_state [f"{ prefix } .atten_func.proj.weight" ] = state .pop (f"{ o_name } .weight" )
246
+ converted_state [f"{ prefix } .atten_func.output_projection.weight" ] = state .pop (
247
+ f"{ o_name } .weight"
248
+ )
247
249
if config .attn_config .output_proj_use_bias :
248
- converted_state [f"{ prefix } .atten_func.proj.bias" ] = state .pop (f"{ o_name } .bias" )
250
+ converted_state [f"{ prefix } .atten_func.output_projection.bias" ] = state .pop (
251
+ f"{ o_name } .bias"
252
+ )
249
253
250
254
def _map_norm (
251
255
self ,
You can’t perform that action at this time.
0 commit comments