Skip to content

Commit 552e0ee

Browse files
authored
Fix old model loading (#318)
1 parent 0ed2e7c commit 552e0ee

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

torchmdnet/models/model.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -235,11 +235,22 @@ def load_model(filepath, args=None, device="cpu", return_std=False, **kwargs):
235235
state_dict = {re.sub(r"^model\.", "", k): v for k, v in ckpt["state_dict"].items()}
236236
# In ET, before we had output_model.output_network.{0,1}.update_net.[0-9].{weight,bias}
237237
# Now we have output_model.output_network.{0,1}.update_net.layers.[0-9].{weight,bias}
238+
# In other models, we had output_model.output_network.{0,1}.{weight,bias},
239+
# which is now output_model.output_network.layers.{0,1}.{weight,bias}
238240
# This change was introduced in https://github.com/torchmd/torchmd-net/pull/314
239-
state_dict = {
240-
re.sub(r"update_net\.(\d+)\.", r"update_net.layers.\1.", k): v
241-
for k, v in state_dict.items()
242-
}
241+
patterns = [
242+
(
243+
r"output_model.output_network.(\d+).update_net.(\d+).",
244+
r"output_model.output_network.\1.update_net.layers.\2.",
245+
),
246+
(
247+
r"output_model.output_network.([02]).(weight|bias)",
248+
r"output_model.output_network.layers.\1.\2",
249+
),
250+
]
251+
for p in patterns:
252+
state_dict = {re.sub(p[0], p[1], k): v for k, v in state_dict.items()}
253+
243254
model.load_state_dict(state_dict)
244255
return model.to(device)
245256

0 commit comments

Comments
 (0)