@@ -235,11 +235,22 @@ def load_model(filepath, args=None, device="cpu", return_std=False, **kwargs):
     state_dict = {re.sub(r"^model\.", "", k): v for k, v in ckpt["state_dict"].items()}
     # In ET, before we had output_model.output_network.{0,1}.update_net.[0-9].{weight,bias}
     # Now we have output_model.output_network.{0,1}.update_net.layers.[0-9].{weight,bias}
+    # In other models, we had output_model.output_network.{0,1}.{weight,bias},
+    # which is now output_model.output_network.layers.{0,1}.{weight,bias}
     # This change was introduced in https://github.com/torchmd/torchmd-net/pull/314
-    state_dict = {
-        re.sub(r"update_net\.(\d+)\.", r"update_net.layers.\1.", k): v
-        for k, v in state_dict.items()
-    }
+    patterns = [
+        (
+            r"output_model.output_network.(\d+).update_net.(\d+).",
+            r"output_model.output_network.\1.update_net.layers.\2.",
+        ),
+        (
+            r"output_model.output_network.([02]).(weight|bias)",
+            r"output_model.output_network.layers.\1.\2",
+        ),
+    ]
+    for p in patterns:
+        state_dict = {re.sub(p[0], p[1], k): v for k, v in state_dict.items()}
+
     model.load_state_dict(state_dict)
     return model.to(device)
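For illustration, a minimal standalone sketch (with hypothetical old-style checkpoint keys, not taken from a real checkpoint) of how the two patterns rewrite legacy keys into the new layers-based layout:

    import re

    patterns = [
        (
            r"output_model.output_network.(\d+).update_net.(\d+).",
            r"output_model.output_network.\1.update_net.layers.\2.",
        ),
        (
            r"output_model.output_network.([02]).(weight|bias)",
            r"output_model.output_network.layers.\1.\2",
        ),
    ]

    # Hypothetical example keys, purely for illustration
    old_keys = [
        "output_model.output_network.0.update_net.0.weight",  # ET-style key
        "output_model.output_network.0.weight",               # other-model-style key
    ]

    for k in old_keys:
        # Apply both patterns in order, just like the loop over state_dict above
        for old, new in patterns:
            k = re.sub(old, new, k)
        print(k)
    # output_model.output_network.0.update_net.layers.0.weight
    # output_model.output_network.layers.0.weight

Each pattern only touches keys that match its shape, so already-renamed ET keys are not altered again by the second pattern.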