Skip to content

Commit 8fc74ba

Browse files
committed
Make state dictionary compatible with previous checkpoints
1 parent 8dca199 commit 8fc74ba

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

torchmdnet/models/model.py

+7
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,13 @@ def load_model(filepath, args=None, device="cpu", return_std=False, **kwargs):
233233
model.prior_model[-1].enable = True
234234

235235
state_dict = {re.sub(r"^model\.", "", k): v for k, v in ckpt["state_dict"].items()}
236+
# In ET, before we had output_model.output_network.{0,1}.update_net.[0-9].{weight,bias}
237+
# Now we have output_model.output_network.{0,1}.update_net.layers.[0-9].{weight,bias}
238+
# This change was introduced in https://github.com/torchmd/torchmd-net/pull/314
239+
state_dict = {
240+
re.sub(r"update_net\.(\d+)\.", r"update_net.layers.\1.", k): v
241+
for k, v in state_dict.items()
242+
}
236243
model.load_state_dict(state_dict)
237244
return model.to(device)
238245

0 commit comments

Comments
 (0)