Skip to content

Commit 19a168c

Browse files
Fix the tiny llama conversion issue (#7)

* Fix the tiny llama conversion issue
* Fix formatting.

Co-authored-by: Advait Jain <[email protected]>
1 parent 2d4e18e commit 19a168c

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

ai_edge_torch/generative/utilities/loader.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -228,24 +228,28 @@ def _map_attention(
228228
q_name = self._names.attn_query_proj.format(idx)
229229
k_name = self._names.attn_key_proj.format(idx)
230230
v_name = self._names.attn_value_proj.format(idx)
231-
converted_state[f"{prefix}.atten_func.attn.weight"] = self._fuse_qkv(
231+
converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = self._fuse_qkv(
232232
config,
233233
state.pop(f"{q_name}.weight"),
234234
state.pop(f"{k_name}.weight"),
235235
state.pop(f"{v_name}.weight"),
236236
)
237237
if config.attn_config.qkv_use_bias:
238-
converted_state[f"{prefix}.atten_func.attn.bias"] = self._fuse_qkv(
238+
converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = self._fuse_qkv(
239239
config,
240240
state.pop(f"{q_name}.bias"),
241241
state.pop(f"{k_name}.bias"),
242242
state.pop(f"{v_name}.bias"),
243243
)
244244

245245
o_name = self._names.attn_output_proj.format(idx)
246-
converted_state[f"{prefix}.atten_func.proj.weight"] = state.pop(f"{o_name}.weight")
246+
converted_state[f"{prefix}.atten_func.output_projection.weight"] = state.pop(
247+
f"{o_name}.weight"
248+
)
247249
if config.attn_config.output_proj_use_bias:
248-
converted_state[f"{prefix}.atten_func.proj.bias"] = state.pop(f"{o_name}.bias")
250+
converted_state[f"{prefix}.atten_func.output_projection.bias"] = state.pop(
251+
f"{o_name}.bias"
252+
)
249253

250254
def _map_norm(
251255
self,

0 commit comments

Comments (0)